Merge remote-tracking branch 'origin/4.0.0a1-dev' into bugfix/network-setup

This commit is contained in:
Marek Wolan
2025-02-17 18:16:21 +00:00
41 changed files with 141 additions and 49 deletions

View File

@@ -52,4 +52,4 @@ class Router(NetworkNode, identifier="router"):
Changes to YAML file.
=====================
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 128
ports: []

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -70,7 +70,7 @@ class AbstractAgent(BaseModel, ABC):
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
logger: AgentLog = None
history: List[AgentHistoryItem] = []
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
@@ -79,6 +79,11 @@ class AbstractAgent(BaseModel, ABC):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
def __init__(self, **kwargs):
"""Initialise and setup agent logger."""
super().__init__(**kwargs)
self.logger: AgentLog = AgentLog(agent_name=kwargs["config"]["ref"])
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if discriminator is None:

View File

@@ -163,16 +163,6 @@ class Network(SimComponent):
:param links: Include link details in the output. Defaults to True.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
nodes_type_map = {
"Router": self.router_nodes,
"Firewall": self.firewall_nodes,
"Switch": self.switch_nodes,
"Server": self.server_nodes,
"Computer": self.computer_nodes,
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
if markdown:
@@ -189,21 +179,20 @@ class Network(SimComponent):
table.set_style(MARKDOWN)
table.align = "l"
table.title = "IP Addresses"
for nodes in nodes_type_map.values():
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
for node in self.nodes.values():
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
if links:
@@ -215,22 +204,21 @@ class Network(SimComponent):
table.align = "l"
table.title = "Links"
links = list(self.links.values())
for nodes in nodes_type_map.values():
for node in nodes:
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
for node in self.nodes.values():
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
print(table)
def clear_links(self):

View File

@@ -1195,7 +1195,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
"""Request should take the form [username, password, remote_ip_address]."""
username, password, remote_ip_address = request
response = RequestResponse.from_bool(self.remote_login(username, password, remote_ip_address))
response.data = {"remote_hostname": self.parent.hostname, "username": username}
response.data = {"remote_hostname": self.parent.config.hostname, "username": username}
return response
rm.add_request("remote_login", RequestType(func=_remote_login))
@@ -1228,7 +1228,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
table.title = f"{self.parent.config.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -4,6 +4,9 @@
# | node_a |------| switch_1 |------| node_b |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,8 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -30,6 +30,9 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -30,6 +30,9 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)

View File

@@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)

View File

@@ -4,7 +4,7 @@ from primaite.game.agent.scripted_agents.random_agent import RandomAgent
def test_creating_empty_agent():
agent = RandomAgent()
agent = RandomAgent(config={"ref": "Empty Agent"})
assert len(agent.action_manager.action_map) == 0
assert isinstance(agent.observation_manager.obs, NullObservation)
assert len(agent.reward_function.reward_components) == 0