#2769 - Add user login observations
This commit is contained in:
@@ -10,6 +10,7 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a firewall observation instance.
|
||||
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_users: bool = include_users
|
||||
self.max_users: int = 3
|
||||
"""Maximum number of remote sessions observable, excess sessions are truncated."""
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
firewall_state = access_from_nested_dict(state, self.where)
|
||||
if firewall_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
num_rules=config.num_rules,
|
||||
include_users=config.include_users,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user