#Bug and test fixes

This commit is contained in:
Charlie Crane
2025-02-14 11:38:15 +00:00
parent 56699d2377
commit 7e138d1d61
7 changed files with 34 additions and 45 deletions

View File

@@ -80,7 +80,7 @@ class AbstractAgent(BaseModel, ABC):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
def __init__(self, **kwargs):
"""init"""
"""Initialise and setup agent logger"""
super().__init__(**kwargs)
self.logger: AgentLog = AgentLog(agent_name=kwargs["config"]["ref"])

View File

@@ -163,15 +163,6 @@ class Network(SimComponent):
:param links: Include link details in the output. Defaults to True.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
nodes_type_map = {
"Router": self.router_nodes,
"Firewall": self.firewall_nodes,
"Switch": self.switch_nodes,
"Server": self.server_nodes,
"Computer": self.computer_nodes,
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
@@ -189,19 +180,18 @@ class Network(SimComponent):
table.set_style(MARKDOWN)
table.align = "l"
table.title = "IP Addresses"
for nodes in nodes_type_map.values():
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
for node in self.nodes.values():
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
@@ -215,22 +205,21 @@ class Network(SimComponent):
table.align = "l"
table.title = "Links"
links = list(self.links.values())
for nodes in nodes_type_map.values():
for node in nodes:
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
for node in self.nodes.values():
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
print(table)
def clear_links(self):

View File

@@ -1197,7 +1197,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
"""Request should take the form [username, password, remote_ip_address]."""
username, password, remote_ip_address = request
response = RequestResponse.from_bool(self.remote_login(username, password, remote_ip_address))
response.data = {"remote_hostname": self.parent.hostname, "username": username}
response.data = {"remote_hostname": self.parent.config.hostname, "username": username}
return response
rm.add_request("remote_login", RequestType(func=_remote_login))
@@ -1230,7 +1230,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
table.title = f"{self.parent.config.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""

View File

@@ -20,6 +20,7 @@ from primaite.utils.validation.port import Port
if TYPE_CHECKING:
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.network.hardware.base import Node
class SoftwareType(Enum):
@@ -110,6 +111,7 @@ class Software(SimComponent, ABC):
"The folder on the file system the Software uses."
_fixing_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
# parent: Optional[Node] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)

View File

@@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)

View File

@@ -4,7 +4,7 @@ from primaite.game.agent.scripted_agents.random_agent import RandomAgent
def test_creating_empty_agent():
agent = RandomAgent()
agent = RandomAgent(config={"ref" :"Empty Agent"})
assert len(agent.action_manager.action_map) == 0
assert isinstance(agent.observation_manager.obs, NullObservation)
assert len(agent.reward_function.reward_components) == 0