Merge remote-tracking branch 'origin/dev' into feature/2736-instantaneous-rewards
This commit is contained in:
@@ -10,12 +10,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Random Number Generator Seeding by specifying a random number seed in the config file.
|
||||
- Implemented Terminal service class, providing a generic terminal simulation.
|
||||
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
|
||||
- Added actions to establish SSH connections, send commands remotely and terminate SSH connections.
|
||||
- Added actions to change users' passwords.
|
||||
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
|
||||
main port they're assigned.
|
||||
- Added reward calculation details to AgentHistoryItem.
|
||||
|
||||
### Changed
|
||||
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
|
||||
- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty`
|
||||
- Node observations can now be configured to show the number of active local and remote logins.
|
||||
|
||||
### Fixed
|
||||
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
|
||||
|
||||
@@ -1071,6 +1071,83 @@ class NodeNetworkServiceReconAction(AbstractAction):
|
||||
]
|
||||
|
||||
|
||||
class NodeAccountsChangePasswordAction(AbstractAction):
|
||||
"""Action which changes the password for a user."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
node_name,
|
||||
"service",
|
||||
"UserManager",
|
||||
"change_password",
|
||||
username,
|
||||
current_password,
|
||||
new_password,
|
||||
]
|
||||
|
||||
|
||||
class NodeSessionsRemoteLoginAction(AbstractAction):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"ssh_to_remote",
|
||||
username,
|
||||
password,
|
||||
remote_ip,
|
||||
]
|
||||
|
||||
|
||||
class NodeSessionsRemoteLogoutAction(AbstractAction):
|
||||
"""Action which performs a remote session logout."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: str, remote_ip: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip]
|
||||
|
||||
|
||||
class NodeSendRemoteCommandAction(AbstractAction):
|
||||
"""Action which sends a terminal command to a remote node via SSH."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"send_remote_command",
|
||||
remote_ip,
|
||||
{"command": command},
|
||||
]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -1122,6 +1199,10 @@ class ActionManager:
|
||||
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
|
||||
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
|
||||
"NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction,
|
||||
"SSH_TO_REMOTE": NodeSessionsRemoteLoginAction,
|
||||
"SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction,
|
||||
"NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
|
||||
|
||||
reward: Optional[float] = None
|
||||
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a firewall observation instance.
|
||||
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_users: bool = include_users
|
||||
self.max_users: int = 3
|
||||
"""Maximum number of remote sessions observable, excess sessions are truncated."""
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
firewall_state = access_from_nested_dict(state, self.where)
|
||||
if firewall_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
num_rules=config.num_rules,
|
||||
include_users=config.include_users,
|
||||
)
|
||||
|
||||
@@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
monitored_traffic: Optional[Dict],
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a host observation instance.
|
||||
@@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_num_access = include_num_access
|
||||
self.include_users = include_users
|
||||
self.max_users: int = 3
|
||||
"""Maximum number of remote sessions observable, excess sessions are truncated."""
|
||||
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
@@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_file_creations"] = 0
|
||||
self.default_observation["num_file_deletions"] = 0
|
||||
if self.include_users:
|
||||
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_users:
|
||||
sess = node_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
if self.include_num_access:
|
||||
shape["num_file_creations"] = spaces.Discrete(4)
|
||||
shape["num_file_deletions"] = spaces.Discrete(4)
|
||||
if self.include_users:
|
||||
shape["users"] = spaces.Dict(
|
||||
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
|
||||
)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
@@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
include_users=config.include_users,
|
||||
)
|
||||
|
||||
@@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""Flag to include the number of accesses."""
|
||||
file_system_requires_scan: bool = True
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
@@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
host_config.include_num_access = config.include_num_access
|
||||
if host_config.file_system_requires_scan is None:
|
||||
host_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
if host_config.include_users is None:
|
||||
host_config.include_users = config.include_users
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
@@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
router_config.protocol_list = config.protocol_list
|
||||
if router_config.num_rules is None:
|
||||
router_config.num_rules = config.num_rules
|
||||
if router_config.include_users is None:
|
||||
router_config.include_users = config.include_users
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
@@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
firewall_config.protocol_list = config.protocol_list
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
if firewall_config.include_users is None:
|
||||
firewall_config.include_users = config.include_users
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
|
||||
@@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
ports: List[PortObservation],
|
||||
num_ports: int,
|
||||
acl: ACLObservation,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a router observation instance.
|
||||
@@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
:type num_ports: int
|
||||
:param acl: ACL observation representing the access control list of the router.
|
||||
:type acl: ACLObservation
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.ports: List[PortObservation] = ports
|
||||
self.acl: ACLObservation = acl
|
||||
self.num_ports: int = num_ports
|
||||
|
||||
self.include_users: bool = include_users
|
||||
self.max_users: int = 3
|
||||
"""Maximum number of remote sessions observable, excess sessions are truncated."""
|
||||
while len(self.ports) < num_ports:
|
||||
self.ports.append(PortObservation(where=None))
|
||||
while len(self.ports) > num_ports:
|
||||
@@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
sess = router_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users)
|
||||
|
||||
@@ -47,7 +47,15 @@ class AbstractReward:
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -322,6 +342,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last successful
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
@@ -333,10 +359,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
]
|
||||
|
||||
if request_attempted: # if agent makes request, always recalculate fresh value
|
||||
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
elif not self.sticky: # if no new request and not sticky, set reward to 0
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
self.reward = 0.0
|
||||
else: # if no new request and sticky, reuse reward value from last step
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
pass
|
||||
|
||||
return self.reward
|
||||
@@ -384,7 +413,15 @@ class SharedReward(AbstractReward):
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
"""Simply access the other agent's reward and return it.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
@@ -417,7 +454,15 @@ class ActionPenalty(AbstractReward):
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied."""
|
||||
"""Calculate the penalty to be applied.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
return self.do_nothing_penalty
|
||||
else:
|
||||
|
||||
@@ -990,6 +990,7 @@ class UserManager(Service):
|
||||
if user and user.password == current_password:
|
||||
user.password = new_password
|
||||
self.sys_log.info(f"{self.name}: Password changed for {username}")
|
||||
self._user_session_manager._logout_user(user=user)
|
||||
return True
|
||||
self.sys_log.info(f"{self.name}: Password change failed for {username}")
|
||||
return False
|
||||
@@ -1027,6 +1028,10 @@ class UserManager(Service):
|
||||
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
|
||||
return False
|
||||
|
||||
@property
|
||||
def _user_session_manager(self) -> "UserSessionManager":
|
||||
return self.software_manager.software["UserSessionManager"] # noqa
|
||||
|
||||
|
||||
class UserSession(SimComponent):
|
||||
"""
|
||||
@@ -1260,7 +1265,8 @@ class UserSessionManager(Service):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["active_remote_logins"] = len(self.remote_sessions)
|
||||
state["current_local_user"] = None if not self.local_session else self.local_session.user.username
|
||||
state["active_remote_sessions"] = list(self.remote_sessions.keys())
|
||||
return state
|
||||
|
||||
@property
|
||||
@@ -1435,6 +1441,19 @@ class UserSessionManager(Service):
|
||||
"""
|
||||
return self._logout(local=False, remote_session_id=remote_session_id)
|
||||
|
||||
def _logout_user(self, user: Union[str, User]) -> bool:
|
||||
"""End a user session by username or user object."""
|
||||
if isinstance(user, str):
|
||||
user = self._user_manager.users[user] # grab user object from username
|
||||
for sess_id, session in self.remote_sessions.items():
|
||||
if session.user is user:
|
||||
self._logout(local=False, remote_session_id=sess_id)
|
||||
return True
|
||||
if self.local_user_logged_in and self.local_session.user is user:
|
||||
self.local_logout()
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def local_user_logged_in(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
|
||||
|
||||
# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor
|
||||
# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication
|
||||
class TerminalClientConnection(BaseModel):
|
||||
"""
|
||||
TerminalClientConnection Class.
|
||||
@@ -92,7 +94,7 @@ class LocalTerminalConnection(TerminalClientConnection):
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return None
|
||||
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
|
||||
return self.parent_terminal.execute(command)
|
||||
|
||||
|
||||
class RemoteTerminalConnection(TerminalClientConnection):
|
||||
@@ -162,22 +164,6 @@ class Terminal(Service):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""Initialise Request manager."""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
"send",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
|
||||
)
|
||||
|
||||
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
login = self._process_local_login(username=request[0], password=request[1])
|
||||
if login:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={
|
||||
"ip_address": login.ip_address,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(status="failure", data={"reason": "Invalid login credentials"})
|
||||
|
||||
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
|
||||
@@ -191,10 +177,34 @@ class Terminal(Service):
|
||||
else:
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request(
|
||||
"ssh_to_remote",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Logoff from remote connection."""
|
||||
ip_address = IPv4Address(request[0])
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = self._disconnect(remote_connection.connection_uuid)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={"reason": "No remote connection held."},
|
||||
)
|
||||
|
||||
rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff))
|
||||
|
||||
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Execute an instruction."""
|
||||
command: str = request[0]
|
||||
ip_address: IPv4Address = IPv4Address(request[1])
|
||||
ip_address: IPv4Address = IPv4Address(request[0])
|
||||
command: str = request[1]["command"]
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = remote_connection.execute(command)
|
||||
@@ -209,30 +219,11 @@ class Terminal(Service):
|
||||
data={},
|
||||
)
|
||||
|
||||
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Logoff from connection."""
|
||||
connection_uuid = request[0]
|
||||
self.parent.user_session_manager.local_logout(connection_uuid)
|
||||
self._disconnect(connection_uuid)
|
||||
return RequestResponse(status="success", data={})
|
||||
|
||||
rm.add_request(
|
||||
"Login",
|
||||
request_type=RequestType(func=_login),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"Remote Login",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"Execute",
|
||||
"send_remote_command",
|
||||
request_type=RequestType(func=remote_execute_request),
|
||||
)
|
||||
|
||||
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
|
||||
@@ -280,13 +271,9 @@ class Terminal(Service):
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
|
||||
return None
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
if ip_address:
|
||||
# Assuming that if IP is passed we are connecting to remote
|
||||
return self._send_remote_login(
|
||||
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
|
||||
)
|
||||
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
|
||||
else:
|
||||
return self._process_local_login(username=username, password=password)
|
||||
|
||||
@@ -313,6 +300,9 @@ class Terminal(Service):
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id):
|
||||
self._disconnect(connection_id)
|
||||
return False
|
||||
return connection_id in self._connections
|
||||
|
||||
def _send_remote_login(
|
||||
@@ -320,32 +310,24 @@ class Terminal(Service):
|
||||
username: str,
|
||||
password: str,
|
||||
ip_address: IPv4Address,
|
||||
connection_request_id: str,
|
||||
connection_request_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Optional[RemoteTerminalConnection]:
|
||||
"""Send a remote login attempt and connect to Node.
|
||||
|
||||
:param: username: Username used to connect to the remote node.
|
||||
:type: username: str
|
||||
|
||||
:param: password: Password used to connect to the remote node
|
||||
:type: password: str
|
||||
|
||||
:param: ip_address: Target Node IP address for login attempt.
|
||||
:type: ip_address: IPv4Address
|
||||
|
||||
:param: connection_request_id: Connection Request ID
|
||||
:type: connection_request_id: str
|
||||
|
||||
:param: connection_request_id: Connection Request ID, if not provided, a new one is generated
|
||||
:type: connection_request_id: Optional[str]
|
||||
:param: is_reattempt: True if the request has been reattempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
|
||||
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
|
||||
|
||||
"""
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
|
||||
)
|
||||
connection_request_id = connection_request_id or str(uuid4())
|
||||
if is_reattempt:
|
||||
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
||||
if valid_connection_request:
|
||||
@@ -360,6 +342,9 @@ class Terminal(Service):
|
||||
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
|
||||
return None
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
|
||||
)
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
|
||||
|
||||
942
tests/assets/configs/data_manipulation.yaml
Normal file
942
tests/assets/configs/data_manipulation.yaml
Normal file
@@ -0,0 +1,942 @@
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 0
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 0
|
||||
20:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
21:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 0
|
||||
22:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 1
|
||||
23:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 1
|
||||
24:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
25:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 1
|
||||
26: # old action num: 18
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
27:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 2
|
||||
28:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 2
|
||||
29:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 2
|
||||
30:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
31:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 3
|
||||
32:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 3
|
||||
33:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 3
|
||||
34:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 4
|
||||
35:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 4
|
||||
36:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 4
|
||||
37:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 4
|
||||
38:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 5
|
||||
39: # old action num: 19 # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 5
|
||||
40: # old action num: 20
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 5
|
||||
41: # old action num: 21
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 5
|
||||
42:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 6
|
||||
43:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
44:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 6
|
||||
45:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 6
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
63: # old action num: 39
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
64: # old action num: 40
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
65: # old action num: 41
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
66: # old action num: 42
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
67: # old action num: 43
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
68: # old action num: 44
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
69: # old action num: 45
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
70: # old action num: 46
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
71: # old action num: 47
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
72: # old action num: 48
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
73: # old action num: 49
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
74: # old action num: 50
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
75: # old action num: 51
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
76: # old action num: 52
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
77: # old action num: 53
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.40
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_1_green_user
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_2_green_user
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
18:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
19:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
20:
|
||||
action: PERMIT
|
||||
src_port: FTP
|
||||
dst_port: FTP
|
||||
21:
|
||||
action: PERMIT
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- hostname: switch_2
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- hostname: domain_controller
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- hostname: web_server
|
||||
type: server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: WebServer
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
|
||||
- hostname: database_server
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- type: FTPClient
|
||||
|
||||
- hostname: backup_server
|
||||
type: server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: FTPServer
|
||||
|
||||
- hostname: security_suite
|
||||
type: server
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
network_interfaces:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: web_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_hostname: database_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_hostname: backup_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -458,6 +458,10 @@ def game_and_agent():
|
||||
{"type": "HOST_NIC_DISABLE"},
|
||||
{"type": "NETWORK_PORT_ENABLE"},
|
||||
{"type": "NETWORK_PORT_DISABLE"},
|
||||
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
|
||||
{"type": "SSH_TO_REMOTE"},
|
||||
{"type": "SESSIONS_REMOTE_LOGOFF"},
|
||||
{"type": "NODE_SEND_REMOTE_COMMAND"},
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.base import UserManager
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def game_and_agent_fixture(game_and_agent):
|
||||
"""Create a game with a simple agent that can be controlled by the tests."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4)
|
||||
|
||||
return (game, agent)
|
||||
|
||||
|
||||
def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# create a new user account on server_1 that will be logged into remotely
|
||||
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
|
||||
server_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"SSH_TO_REMOTE",
|
||||
{
|
||||
"node_id": 0,
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.history[-1].response.status == "success"
|
||||
|
||||
connection_established = False
|
||||
for conn_str, conn_obj in client_1.terminal.connections.items():
|
||||
conn_obj: RemoteTerminalConnection
|
||||
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
|
||||
connection_established = True
|
||||
if not connection_established:
|
||||
pytest.fail("Remote SSH connection could not be established")
|
||||
|
||||
|
||||
def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# create a new user account on server_1 that will be logged into remotely
|
||||
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
|
||||
server_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"SSH_TO_REMOTE",
|
||||
{
|
||||
"node_id": 0,
|
||||
"username": "user123",
|
||||
"password": "wrong_password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.history[-1].response.status == "failure"
|
||||
|
||||
connection_established = False
|
||||
for conn_str, conn_obj in client_1.terminal.connections.items():
|
||||
conn_obj: RemoteTerminalConnection
|
||||
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
|
||||
connection_established = True
|
||||
if connection_established:
|
||||
pytest.fail("Remote SSH connection was established despite wrong password")
|
||||
|
||||
|
||||
def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# create a new user account on server_1 that will be logged into remotely
|
||||
server_1_um: UserManager = server_1.software_manager.software["UserManager"]
|
||||
server_1_um.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"NODE_ACCOUNTS_CHANGE_PASSWORD",
|
||||
{
|
||||
"node_id": 1, # server_1
|
||||
"username": "user123",
|
||||
"current_password": "password",
|
||||
"new_password": "different_password",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.history[-1].response.status == "success"
|
||||
assert server_1_um.users["user123"].password == "different_password"
|
||||
|
||||
|
||||
def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# create a new user account on server_1 that will be logged into remotely
|
||||
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
|
||||
server_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
# Log in remotely
|
||||
action = (
|
||||
"SSH_TO_REMOTE",
|
||||
{
|
||||
"node_id": 0,
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# Change password
|
||||
action = (
|
||||
"NODE_ACCOUNTS_CHANGE_PASSWORD",
|
||||
{
|
||||
"node_id": 1, # server_1
|
||||
"username": "user123",
|
||||
"current_password": "password",
|
||||
"new_password": "different_password",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# Assert that the user cannot execute an action
|
||||
action = (
|
||||
"NODE_SEND_REMOTE_COMMAND",
|
||||
{
|
||||
"node_id": 0,
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert server_1.file_system.get_folder("folder123") is None
|
||||
assert server_1.file_system.get_file("folder123", "doggo.pdf") is None
|
||||
@@ -33,6 +33,7 @@ def test_firewall_observation():
|
||||
wildcard_list=["0.0.0.255", "0.0.0.1"],
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
include_users=False,
|
||||
)
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
|
||||
@@ -39,6 +39,7 @@ def test_host_observation(simulation):
|
||||
folders=[],
|
||||
network_interfaces=[],
|
||||
file_system_requires_scan=True,
|
||||
include_users=False,
|
||||
)
|
||||
|
||||
assert host_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_router_observation():
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
)
|
||||
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
|
||||
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False)
|
||||
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env_with_ssh() -> PrimaiteGymEnv:
|
||||
"""Build data manipulation environment with SSH port open on router."""
|
||||
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
|
||||
env.agent.flatten_obs = False
|
||||
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
|
||||
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3)
|
||||
return env
|
||||
|
||||
|
||||
def extract_login_numbers_from_obs(obs):
|
||||
"""Traverse the observation dictionary and return number of user sessions for all nodes."""
|
||||
login_nums = {}
|
||||
for node_name, node_obs in obs["NODES"].items():
|
||||
login_nums[node_name] = node_obs.get("users")
|
||||
return login_nums
|
||||
|
||||
|
||||
class TestUserObservations:
|
||||
"""Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins."""
|
||||
|
||||
def test_no_sessions_at_episode_start(self, env_with_ssh):
|
||||
"""Test that all of the login observations start at 0 before any logins occur."""
|
||||
obs, *_ = env_with_ssh.step(0)
|
||||
logins_obs = extract_login_numbers_from_obs(obs)
|
||||
for o in logins_obs.values():
|
||||
assert o["local_login"] == 0
|
||||
assert o["remote_sessions"] == 0
|
||||
|
||||
def test_single_login(self, env_with_ssh: PrimaiteGymEnv):
|
||||
"""Test that performing a remote login increases the remote_sessions observation by 1."""
|
||||
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
|
||||
obs, *_ = env_with_ssh.step(0)
|
||||
logins_obs = extract_login_numbers_from_obs(obs)
|
||||
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
|
||||
assert db_srv_logins_obs["local_login"] == 0
|
||||
assert db_srv_logins_obs["remote_sessions"] == 1
|
||||
for o in logins_obs.values(): # the remaining obs after popping HOST2
|
||||
assert o["local_login"] == 0
|
||||
assert o["remote_sessions"] == 0
|
||||
|
||||
def test_logout(self, env_with_ssh: PrimaiteGymEnv):
|
||||
"""Test that remote_sessions observation correctly decreases upon logout."""
|
||||
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
|
||||
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
|
||||
db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user
|
||||
|
||||
obs, *_ = env_with_ssh.step(0)
|
||||
logins_obs = extract_login_numbers_from_obs(obs)
|
||||
for o in logins_obs.values():
|
||||
assert o["local_login"] == 0
|
||||
assert o["remote_sessions"] == 0
|
||||
|
||||
def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv):
|
||||
"""Log in from 5 remote places and check that only a max of 3 is shown in the observation."""
|
||||
MAX_OBSERVABLE_SESSIONS = 3
|
||||
# Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation
|
||||
obs, *_ = env_with_ssh.step(0)
|
||||
logins_obs = extract_login_numbers_from_obs(obs)
|
||||
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
|
||||
|
||||
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
|
||||
db_srv.user_session_manager.remote_session_timeout_steps = 20
|
||||
db_srv.user_session_manager.max_remote_sessions = 5
|
||||
node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller")
|
||||
|
||||
for i, node_name in enumerate(node_names):
|
||||
node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name)
|
||||
node.terminal._send_remote_login("admin", "admin", "192.168.1.14")
|
||||
|
||||
obs, *_ = env_with_ssh.step(0)
|
||||
logins_obs = extract_login_numbers_from_obs(obs)
|
||||
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
|
||||
|
||||
assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1)
|
||||
assert len(db_srv.user_session_manager.remote_sessions) == i + 1
|
||||
@@ -72,25 +72,26 @@ def test_uc2_rewards(game_and_agent):
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
),
|
||||
ahi = AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
)
|
||||
reward_value = comp.calculate(state, last_action_response=ahi)
|
||||
assert reward_value == 1.0
|
||||
assert ahi.reward_info == {"connection_attempt_status": "success"}
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
ahi = AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
),
|
||||
last_action_response=ahi,
|
||||
)
|
||||
assert reward_value == -1.0
|
||||
assert ahi.reward_info == {"connection_attempt_status": "failure"}
|
||||
|
||||
|
||||
def test_shared_reward():
|
||||
|
||||
Reference in New Issue
Block a user