Merged PR 490: #2735: fixes to broken items

## Summary
*Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?*

## Test process
*How have you tested this (if applicable)?*

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

#2735: fixes to broken items

Related work items: #2735
This commit is contained in:
Czar Echavez
2024-08-01 22:56:31 +00:00
committed by Christopher McCarthy
16 changed files with 127 additions and 166 deletions

View File

@@ -102,9 +102,7 @@ stages:
version: '2.1.x'
- script: |
coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml --cov-fail-under=80
coverage xml -o coverage.xml -i
coverage html -d htmlcov -i
python run_test_and_coverage.py
displayName: 'Run tests and code coverage'
# Run the notebooks

View File

@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes.
- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
- **Software (un)install refactored**: Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
### Fixed

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_latex_report
from report import build_benchmark_md_report
from stable_baselines3 import PPO
import primaite
@@ -188,7 +188,7 @@ def run(
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
session_metadata_dict[i] = json.load(file)
# generate report
build_benchmark_latex_report(
build_benchmark_md_report(
benchmark_start_time=benchmark_start_time,
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),

View File

@@ -234,10 +234,7 @@ def _plot_av_s_per_100_steps_10_nodes(
"""
major_v = primaite.__version__.split(".")[0]
title = f"Performance of Minor and Bugfix Releases for Major Version {major_v}"
subtitle = (
f"Average Training Time per 100 Steps on 10 Nodes "
f"(target: <= {PLOT_CONFIG['av_s_per_100_steps_10_nodes_benchmark_threshold']} seconds)"
)
subtitle = "Average Training Time per 100 Steps on 10 Nodes "
title = f"{title} <br><sub>{subtitle}</sub>"
layout = go.Layout(
@@ -250,10 +247,6 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
av_s_per_100_steps_10_nodes_benchmark_threshold = PLOT_CONFIG["av_s_per_100_steps_10_nodes_benchmark_threshold"]
# Calculate the appropriate maximum y-axis value
max_y_axis_value = max(max(times), av_s_per_100_steps_10_nodes_benchmark_threshold) + 1
fig.add_trace(
go.Bar(
@@ -267,7 +260,6 @@ def _plot_av_s_per_100_steps_10_nodes(
fig.update_layout(
xaxis_title="PrimAITE Version",
yaxis_title="Avg Time per 100 Steps on 10 Nodes (seconds)",
yaxis=dict(range=[0, max_y_axis_value]),
title=title,
)

View File

@@ -52,7 +52,7 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
"ray[rllib] >= 2.20.0, < 3",
"ray[rllib] >= 2.20.0, <2.33",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",

22
run_test_and_coverage.py Normal file
View File

@@ -0,0 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import subprocess
import sys
from typing import Any
def run_command(command: Any):
"""Runs a command and returns the exit code."""
result = subprocess.run(command, shell=True)
if result.returncode != 0:
sys.exit(result.returncode)
# Run pytest with coverage
run_command(
"coverage run -m --source=primaite pytest -v -o junit_family=xunit2 "
"--junitxml=junit/test-results.xml --cov-fail-under=80"
)
# Generate coverage reports if tests passed
run_command("coverage xml -o coverage.xml -i")
run_command("coverage html -d htmlcov -i")

View File

@@ -6,7 +6,7 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, List, Optional, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
@@ -39,7 +39,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.simulator.system.software import IOSoftware, Software
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
@@ -850,7 +850,6 @@ class UserManager(Service):
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self._request_manager = None
self.start()
@@ -898,6 +897,10 @@ class UserManager(Service):
table.add_row([user.username, user.is_admin, user.disabled])
print(table.get_string(sortby="Username"))
def install(self) -> None:
"""Setup default user during first-time installation."""
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
def _is_last_admin(self, username: str) -> bool:
return username in self.admins and len(self.admins) == 1
@@ -1101,9 +1104,6 @@ class UserSessionManager(Service):
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
node: Node
"""The node associated with this UserSessionManager."""
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
@@ -1184,7 +1184,7 @@ class UserSessionManager(Service):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.node.hostname} User Sessions"
table.title = f"{self.parent.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""
@@ -1473,6 +1473,9 @@ class Node(SimComponent):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.
@@ -1497,22 +1500,19 @@ class Node(SimComponent):
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self.software_manager.install(UserSessionManager, node=self)
self.software_manager.install(UserManager)
self.user_manager.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
self._install_system_software()
@property
def user_manager(self) -> UserManager:
"""The Nodes User Manager."""
return self.software_manager.software["UserManager"] # noqa
return self.software_manager.software.get("UserManager") # noqa
@property
def user_session_manager(self) -> UserSessionManager:
"""The Nodes User Session Manager."""
return self.software_manager.software["UserSessionManager"] # noqa
return self.software_manager.software.get("UserSessionManager") # noqa
def local_login(self, username: str, password: str) -> Optional[str]:
"""
@@ -1735,10 +1735,6 @@ class Node(SimComponent):
self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application))
self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application))
rm.add_request("accounts", RequestType(func=self.user_manager._request_manager)) # noqa
rm.add_request("sessions", RequestType(func=self.user_session_manager._request_manager)) # noqa
return rm
def describe_state(self) -> Dict:
@@ -2101,74 +2097,6 @@ class Node(SimComponent):
else:
return
def install_service(self, service: Service) -> None:
"""
Install a service on this node.
:param service: Service instance that has not been installed on any node yet.
:type service: Service
"""
if service in self:
_LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.")
return
self.services[service.uuid] = service
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""
Uninstall and completely remove service from this node.
:param service: Service object that is currently associated with this node.
:type service: Service
"""
if service not in self:
_LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.")
return
service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
self._service_request_manager.remove_request(service.name)
def install_application(self, application: Application) -> None:
"""
Install an application on this node.
:param application: Application instance that has not been installed on any node yet.
:type application: Application
"""
if application in self:
_LOGGER.warning(
f"Can't add application {application.name} to node {self.hostname}. It's already installed."
)
return
self.applications[application.uuid] = application
application.parent = self
self.sys_log.info(f"Installed application {application.name}")
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
def uninstall_application(self, application: Application) -> None:
"""
Uninstall and completely remove application from this node.
:param application: Application object that is currently associated with this node.
:type application: Application
"""
if application not in self:
_LOGGER.warning(
f"Can't remove application {application.name} from node {self.hostname}. It's not installed."
)
return
self.applications.pop(application.uuid)
application.parent = None
self.sys_log.info(f"Uninstalled application {application.name}")
self._application_request_manager.remove_request(application.name)
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
@@ -2197,6 +2125,11 @@ class Node(SimComponent):
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services

View File

@@ -5,7 +5,13 @@ from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
Node,
UserManager,
UserSessionManager,
)
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import ApplicationOperatingState
@@ -306,8 +312,8 @@ class HostNode(Node):
"NTPClient": NTPClient,
"WebBrowser": WebBrowser,
"NMAP": NMAP,
# "UserSessionManager": UserSessionManager,
# "UserManager": UserManager,
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
}
"""List of system software that is automatically installed on nodes."""
@@ -340,16 +346,6 @@ class HostNode(Node):
"""
return self.software_manager.software.get("ARP")
def _install_system_software(self):
"""
Installs the system software and network services typically found on an operating system.
This method equips the host with essential network services and applications, preparing it for various
network-related tasks and operations.
"""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
def default_gateway_hello(self):
"""
Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address.

View File

@@ -4,14 +4,14 @@ from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.protocols.arp import ARPPacket
@@ -1200,6 +1200,11 @@ class Router(NetworkNode):
RouteTable, RouterARP, and RouterICMP services.
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
@@ -1235,6 +1240,7 @@ class Router(NetworkNode):
resolution within the network. These services are crucial for the router's operation, enabling it to manage
network traffic efficiently.
"""
super()._install_system_software()
self.software_manager.install(RouterICMP)
icmp: RouterICMP = self.software_manager.icmp # noqa
icmp.router = self

View File

@@ -108,6 +108,9 @@ class Switch(NetworkNode):
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestType
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -20,9 +21,7 @@ if TYPE_CHECKING:
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from typing import Type, TypeVar
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
from typing import Type
class SoftwareManager:
@@ -51,7 +50,7 @@ class SoftwareManager:
self.node = parent_node
self.session_manager = session_manager
self.software: Dict[str, Union[Service, Application]] = {}
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
@@ -104,14 +103,12 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftwareClass], **install_kwargs) -> None:
def install(self, software_class: Type[IOSoftware], **install_kwargs):
"""
Install an Application or Service.
:param software_class: The software class.
"""
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
# separation of concerns.
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
@@ -122,19 +119,22 @@ class SoftwareManager:
dns_server=self.dns_server,
**install_kwargs,
)
software.parent = self.node
if isinstance(software, Application):
software.install()
self.node.applications[software.uuid] = software
self.node._application_request_manager.add_request(
software.name, RequestType(func=software._request_manager)
)
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
# add the software to the node's registry after it has been fully initialized
if isinstance(software, Service):
self.node.install_service(software)
elif isinstance(software, Application):
self.node.install_application(software)
self.node.sys_log.info(f"Installed {software.name}")
def uninstall(self, software_name: str):
"""
@@ -142,25 +142,31 @@ class SoftwareManager:
:param software_name: The software name.
"""
if software_name in self.software:
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
if software_name not in self.software:
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.applications.pop(software.uuid)
self.node._application_request_manager.remove_request(software.name)
elif isinstance(software, Service):
self.node.services.pop(software.uuid)
software.uninstall()
self.node._service_request_manager.remove_request(software.name)
software.parent = None
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
return
def send_internal_payload(self, target_software: str, payload: Any):
"""

View File

@@ -37,14 +37,14 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
class TestService(Service):
class DummyService(Service):
"""Test Service class"""
def describe_state(self) -> Dict:
return super().describe_state()
def __init__(self, **kwargs):
kwargs["name"] = "TestService"
kwargs["name"] = "DummyService"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,15 +75,15 @@ def uc2_network() -> Network:
@pytest.fixture(scope="function")
def service(file_system) -> TestService:
return TestService(
name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
def service(file_system) -> DummyService:
return DummyService(
name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service")
)
@pytest.fixture(scope="function")
def service_class():
return TestService
return DummyService
@pytest.fixture(scope="function")

View File

@@ -22,8 +22,7 @@ def test_passing_actions_down(monkeypatch) -> None:
for n in [pc1, pc2, srv, s1]:
sim.network.add_node(n)
database_service = DatabaseService(file_system=srv.file_system)
srv.install_service(database_service)
srv.software_manager.install(DatabaseService)
downloads_folder = pc1.file_system.create_folder("downloads")
pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads")

View File

@@ -23,7 +23,7 @@ def populated_node(
server.power_on()
server.software_manager.install(service_class)
service = server.software_manager.software.get("TestService")
service = server.software_manager.software.get("DummyService")
service.start()
return server, service
@@ -42,7 +42,7 @@ def test_service_on_offline_node(service_class):
computer.power_on()
computer.software_manager.install(service_class)
service: Service = computer.software_manager.software.get("TestService")
service: Service = computer.software_manager.software.get("DummyService")
computer.power_off()

View File

@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import DummyApplication, TestService
from tests.conftest import DummyApplication, DummyService
def test_successful_node_file_system_creation_request(example_network):
@@ -61,7 +61,7 @@ def test_successful_application_requests(example_network):
def test_successful_service_requests(example_network):
net = example_network
server_1 = net.get_node_by_hostname("server_1")
server_1.software_manager.install(TestService)
server_1.software_manager.install(DummyService)
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
for verb in [
@@ -77,7 +77,7 @@ def test_successful_service_requests(example_network):
"scan",
"fix",
]:
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb])
assert resp_1 == RequestResponse(status="success", data={})
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)

View File

@@ -7,6 +7,7 @@ from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.hardware.base import Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.software import SoftwareHealthState
from tests.conftest import DummyApplication, DummyService
@pytest.fixture
@@ -47,7 +48,7 @@ def test_node_shutdown(node):
assert node.operating_state == NodeOperatingState.OFF
def test_node_os_scan(node, service, application):
def test_node_os_scan(node):
"""Test OS Scanning."""
node.operating_state = NodeOperatingState.ON
@@ -55,13 +56,15 @@ def test_node_os_scan(node, service, application):
# TODO implement processes
# add services to node
node.software_manager.install(DummyService)
service = node.software_manager.software.get("DummyService")
service.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_service(service=service)
assert service.health_state_visible == SoftwareHealthState.UNUSED
# add application to node
node.software_manager.install(DummyApplication)
application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.health_state_visible == SoftwareHealthState.UNUSED
# add folder and file to node
@@ -91,7 +94,7 @@ def test_node_os_scan(node, service, application):
assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT
def test_node_red_scan(node, service, application):
def test_node_red_scan(node):
"""Test revealing to red"""
node.operating_state = NodeOperatingState.ON
@@ -99,12 +102,14 @@ def test_node_red_scan(node, service, application):
# TODO implement processes
# add services to node
node.install_service(service=service)
node.software_manager.install(DummyService)
service = node.software_manager.software.get("DummyService")
assert service.revealed_to_red is False
# add application to node
node.software_manager.install(DummyApplication)
application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.revealed_to_red is False
# add folder and file to node