#2769 - Add user login observations

This commit is contained in:
Marek Wolan
2024-08-15 20:16:11 +01:00
parent 7b7adc20f9
commit 1d2705eb1b
7 changed files with 1096 additions and 4 deletions

View File

@@ -10,6 +10,7 @@ from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list: List[int],
protocol_list: List[str],
num_rules: int,
include_users: bool,
) -> None:
"""
Initialise a firewall observation instance.
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
include_users=config.include_users,
)