Merge branch '4.0.0a1-dev' into UC7-migration

This commit is contained in:
Archer Bowen
2025-02-25 13:52:44 +00:00
committed by Marek Wolan
49 changed files with 172 additions and 74 deletions

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 128
ports: []

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -77,7 +77,7 @@ class AbstractAgent(BaseModel, ABC):
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
logger: AgentLog = None
history: List[AgentHistoryItem] = []
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
@@ -86,6 +86,11 @@ class AbstractAgent(BaseModel, ABC):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
def __init__(self, **kwargs):
"""Initialise and setup agent logger."""
super().__init__(**kwargs)
self.logger: AgentLog = AgentLog(agent_name=kwargs["config"]["ref"])
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if discriminator is None:

View File

@@ -449,10 +449,8 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
`default_gateway_hello` method is not defined, ignoring such errors to proceed without interruption.
"""
super().enable()
try:
if hasattr(self._connected_node, "default_gateway_hello"):
self._connected_node.default_gateway_hello()
except AttributeError:
pass
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:

View File

@@ -163,16 +163,6 @@ class Network(SimComponent):
:param links: Include link details in the output. Defaults to True.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
nodes_type_map = {
"Router": self.router_nodes,
"Firewall": self.firewall_nodes,
"Switch": self.switch_nodes,
"Server": self.server_nodes,
"Computer": self.computer_nodes,
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
if markdown:
@@ -189,21 +179,20 @@ class Network(SimComponent):
table.set_style(MARKDOWN)
table.align = "l"
table.title = "IP Addresses"
for nodes in nodes_type_map.values():
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
for node in self.nodes.values():
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
if links:
@@ -215,22 +204,21 @@ class Network(SimComponent):
table.align = "l"
table.title = "Links"
links = list(self.links.values())
for nodes in nodes_type_map.values():
for node in nodes:
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
for node in self.nodes.values():
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
print(table)
def clear_links(self):

View File

@@ -155,7 +155,12 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0}
config={
"type": "switch",
"hostname": f"switch_core_{config.lan_name}",
"start_up_duration": 0,
"num_ports": 24,
}
)
core_switch.power_on()
network.add_node(core_switch)
@@ -183,7 +188,12 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
switch_port = 0
switch_n = 1
switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0}
config={
"type": "switch",
"hostname": f"switch_edge_{switch_n}_{config.lan_name}",
"start_up_duration": 0,
"num_ports": 24,
}
)
switch.power_on()
network.add_node(switch)
@@ -207,6 +217,7 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
"type": "switch",
"hostname": f"switch_edge_{switch_n}_{config.lan_name}",
"start_up_duration": 0,
"num_ports": 24,
}
)
switch.power_on()

View File

@@ -639,10 +639,8 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
`default_gateway_hello` method is not defined, ignoring such errors to proceed without interruption.
"""
super().enable()
try:
if hasattr(self._connected_node, "default_gateway_hello"):
self._connected_node.default_gateway_hello()
except AttributeError:
pass
return True
@abstractmethod
@@ -1211,7 +1209,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
"""Request should take the form [username, password, remote_ip_address]."""
username, password, remote_ip_address = request
response = RequestResponse.from_bool(self.remote_login(username, password, remote_ip_address))
response.data = {"remote_hostname": self.parent.hostname, "username": username}
response.data = {"remote_hostname": self.parent.config.hostname, "username": username}
return response
rm.add_request("remote_login", RequestType(func=_remote_login))
@@ -1244,7 +1242,7 @@ class UserSessionManager(Service, discriminator="user-session-manager"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
table.title = f"{self.parent.config.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""

View File

@@ -334,7 +334,7 @@ class HostNode(Node, discriminator="host-node"):
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
type: Literal["host-node"]
type: Literal["host-node"] = "host-node"
hostname: str = "HostNode"
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
@@ -388,7 +388,7 @@ class HostNode(Node, discriminator="host-node"):
This method is invoked to ensure the host node can communicate with its default gateway, primarily to confirm
network connectivity and populate the ARP cache with the gateway's MAC address.
"""
if self.operating_state == NodeOperatingState.ON and self.default_gateway:
if self.operating_state == NodeOperatingState.ON and self.config.default_gateway:
self.software_manager.arp.get_default_gateway_mac_address()
def receive_frame(self, frame: Frame, from_network_interface: NIC):

View File

@@ -1352,7 +1352,6 @@ class Router(NetworkNode, discriminator="router"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state()
return state

View File

@@ -103,8 +103,8 @@ class Switch(NetworkNode, discriminator="switch"):
type: Literal["switch"] = "switch"
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch. Default is 24."
num_ports: int = 8
"The number of ports on the switch."
config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema())
@@ -139,7 +139,6 @@ class Switch(NetworkNode, discriminator="switch"):
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.config.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state

View File

@@ -138,8 +138,8 @@ class ARP(Service, discriminator="arp"):
break
if use_default_gateway:
if self.software_manager.node.default_gateway:
target_ip_address = self.software_manager.node.default_gateway
if self.software_manager.node.config.default_gateway:
target_ip_address = self.software_manager.node.config.default_gateway
else:
return

View File

@@ -82,7 +82,7 @@ class Software(SimComponent, ABC):
"""Configurable options for all software."""
model_config = ConfigDict(extra="forbid")
starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED
starting_health_state: SoftwareHealthState = SoftwareHealthState.GOOD
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
fixing_duration: int = 2

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -4,6 +4,9 @@
# | node_a |------| switch_1 |------| node_b |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,8 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -30,6 +30,9 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
ports:
- ARP

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -30,6 +30,9 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -4,6 +4,9 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
metadata:
version: 3.0
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
io_settings:
save_agent_actions: true
save_step_metadata: true

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -1,3 +1,6 @@
metadata:
version: 3.0
game:
max_episode_length: 256
ports:

View File

@@ -4,7 +4,7 @@ from primaite.game.agent.scripted_agents.random_agent import RandomAgent
def test_creating_empty_agent():
agent = RandomAgent()
agent = RandomAgent(config={"ref": "Empty Agent"})
assert len(agent.action_manager.action_map) == 0
assert isinstance(agent.observation_manager.obs, NullObservation)
assert len(agent.reward_function.reward_components) == 0

View File

@@ -17,4 +17,3 @@ def switch() -> Switch:
def test_describe_state(switch):
state = switch.describe_state()
assert len(state.get("ports")) is 8
assert state.get("num_ports") is 8

View File

@@ -18,7 +18,7 @@ def test_scan(application):
def test_run_application(application):
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_actual == SoftwareHealthState.UNUSED
assert application.health_state_actual == SoftwareHealthState.GOOD
application.run()
assert application.operating_state == ApplicationOperatingState.RUNNING
@@ -37,9 +37,9 @@ def test_close_application(application):
def test_application_describe_states(application):
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_actual == SoftwareHealthState.UNUSED
assert application.health_state_actual == SoftwareHealthState.GOOD
assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual")
assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual")
application.run()
assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual")

View File

@@ -22,7 +22,7 @@ def test_scan(service):
def test_start_service(service):
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.start()
assert service.operating_state == ServiceOperatingState.RUNNING
@@ -43,7 +43,7 @@ def test_pause_and_resume_service(service):
assert service.operating_state == ServiceOperatingState.STOPPED
service.resume()
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.start()
assert service.health_state_actual == SoftwareHealthState.GOOD
@@ -58,11 +58,11 @@ def test_pause_and_resume_service(service):
def test_restart(service):
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.restart()
# Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.start()
assert service.operating_state == ServiceOperatingState.RUNNING
@@ -157,11 +157,11 @@ def test_service_fixing(service):
def test_enable_disable(service):
service.disable()
assert service.operating_state == ServiceOperatingState.DISABLED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.enable()
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_overwhelm_service(service):

View File

@@ -39,6 +39,6 @@ def test_software_creation(software):
def test_software_set_health_state(software):
assert software.health_state_actual == SoftwareHealthState.UNUSED
software.set_health_state(SoftwareHealthState.GOOD)
assert software.health_state_actual == SoftwareHealthState.GOOD
software.set_health_state(SoftwareHealthState.COMPROMISED)
assert software.health_state_actual == SoftwareHealthState.COMPROMISED