#3062 - Remove discriminators from abstract classes and fix remaining old discriminator names

This commit is contained in:
Marek Wolan
2025-02-05 10:12:13 +00:00
parent 0a6b604afd
commit 4a472c5c75
28 changed files with 55 additions and 79 deletions

View File

@@ -12,8 +12,6 @@ from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel, ABC): class AbstractAction(BaseModel, ABC):
"""Base class for actions.""" """Base class for actions."""
config: "AbstractAction.ConfigSchema"
class ConfigSchema(BaseModel, ABC): class ConfigSchema(BaseModel, ABC):
"""Base configuration schema for Actions.""" """Base configuration schema for Actions."""

View File

@@ -21,9 +21,7 @@ __all__ = (
class ACLAddRuleAbstractAction(AbstractAction, ABC): class ACLAddRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for ACL add rule actions.""" """Base abstract class for ACL add rule actions."""
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL add rule abstract actions.""" """Configuration Schema base for ACL add rule abstract actions."""
src_ip: Union[IPV4Address, Literal["ALL"]] src_ip: Union[IPV4Address, Literal["ALL"]]
@@ -37,12 +35,10 @@ class ACLAddRuleAbstractAction(AbstractAction, ABC):
dst_wildcard: Union[IPV4Address, Literal["NONE"]] dst_wildcard: Union[IPV4Address, Literal["NONE"]]
class ACLRemoveRuleAbstractAction(AbstractAction, discriminator="acl-remove-rule-abstract-action"): class ACLRemoveRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for acl remove rule actions.""" """Base abstract class for acl remove rule actions."""
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL remove rule abstract actions.""" """Configuration Schema base for ACL remove rule abstract actions."""
position: int position: int

View File

@@ -23,9 +23,7 @@ class NodeApplicationAbstractAction(AbstractAction, ABC):
inherit from this base class. inherit from this base class.
""" """
config: "NodeApplicationAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node Application actions.""" """Base Configuration schema for Node Application actions."""
node_name: str node_name: str

View File

@@ -24,9 +24,7 @@ class NodeFileAbstractAction(AbstractAction, ABC):
only three parameters can inherit from this base class. only three parameters can inherit from this base class.
""" """
config: "NodeFileAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileAbstractAction.""" """Configuration Schema for NodeFileAbstractAction."""
node_name: str node_name: str

View File

@@ -22,9 +22,7 @@ class NodeFolderAbstractAction(AbstractAction, ABC):
this base class. this base class.
""" """
config: "NodeFolderAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeFolder actions.""" """Base configuration schema for NodeFolder actions."""
node_name: str node_name: str

View File

@@ -16,9 +16,7 @@ class HostNICAbstractAction(AbstractAction, ABC):
base class. base class.
""" """
config: "HostNICAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for HostNIC actions.""" """Base Configuration schema for HostNIC actions."""
node_name: str node_name: str

View File

@@ -1,5 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction from primaite.game.agent.actions.manager import AbstractAction
@@ -8,12 +9,10 @@ from primaite.interface.request import RequestFormat
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction") __all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
class NetworkPortAbstractAction(AbstractAction, discriminator="network-port-abstract"): class NetworkPortAbstractAction(AbstractAction, ABC):
"""Base class for Network port actions.""" """Base class for Network port actions."""
config: "NetworkPortAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NetworkPort actions.""" """Base configuration schema for NetworkPort actions."""
target_nodename: str target_nodename: str

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod from abc import ABC, abstractmethod
from typing import ClassVar, List, Optional, Union from typing import ClassVar, List, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction from primaite.game.agent.actions.manager import AbstractAction
@@ -18,16 +18,14 @@ __all__ = (
) )
class NodeAbstractAction(AbstractAction, discriminator="node-abstract"): class NodeAbstractAction(AbstractAction, ABC):
""" """
Abstract base class for node actions. Abstract base class for node actions.
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class. Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
""" """
config: "NodeAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node actions.""" """Base Configuration schema for Node actions."""
node_name: str node_name: str
@@ -83,12 +81,10 @@ class NodeResetAction(NodeAbstractAction, discriminator="node-reset"):
verb: ClassVar[str] = "reset" verb: ClassVar[str] = "reset"
class NodeNMAPAbstractAction(AbstractAction, discriminator="node-nmap-abstract-action"): class NodeNMAPAbstractAction(AbstractAction, ABC):
"""Base class for NodeNMAP actions.""" """Base class for NodeNMAP actions."""
config: "NodeNMAPAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration Schema for NodeNMAP actions.""" """Base Configuration Schema for NodeNMAP actions."""
target_ip_address: Union[str, List[str]] target_ip_address: Union[str, List[str]]

View File

@@ -1,4 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction from primaite.game.agent.actions.manager import AbstractAction
@@ -17,15 +18,13 @@ __all__ = (
) )
class NodeServiceAbstractAction(AbstractAction, discriminator="node-service-abstract"): class NodeServiceAbstractAction(AbstractAction, ABC):
"""Abstract Action for Node Service related actions. """Abstract Action for Node Service related actions.
Any actions which use node_name and service_name can inherit from this class. Any actions which use node_name and service_name can inherit from this class.
""" """
config: "NodeServiceAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str node_name: str
service_name: str service_name: str
verb: ClassVar[str] verb: ClassVar[str]

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod from abc import ABC, abstractmethod
from primaite.game.agent.actions.manager import AbstractAction from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat from primaite.interface.request import RequestFormat
@@ -11,12 +11,10 @@ __all__ = (
) )
class NodeSessionAbstractAction(AbstractAction, discriminator="node-session-abstract"): class NodeSessionAbstractAction(AbstractAction, ABC):
"""Base class for NodeSession actions.""" """Base class for NodeSession actions."""
config: "NodeSessionAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema, ABC):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeSessionAbstractActions.""" """Base configuration schema for NodeSessionAbstractActions."""
node_name: str node_name: str

View File

@@ -68,7 +68,7 @@ class AbstractAgent(BaseModel, ABC):
) )
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
logger: AgentLog = AgentLog(agent_name="Abstract_Agent") logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = [] history: List[AgentHistoryItem] = []
@@ -161,16 +161,16 @@ class AbstractAgent(BaseModel, ABC):
return agent_class(config=config) return agent_class(config=config)
class AbstractScriptedAgent(AbstractAgent, discriminator="abstract-scripted-agent"): class AbstractScriptedAgent(AbstractAgent, ABC):
"""Base class for actors which generate their own behaviour.""" """Base class for actors which generate their own behaviour."""
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema()) class ConfigSchema(AbstractAgent.ConfigSchema, ABC):
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for AbstractScriptedAgents.""" """Configuration Schema for AbstractScriptedAgents."""
type: str = "AbstractScriptedAgent" type: str = "AbstractScriptedAgent"
config: ConfigSchema = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
@abstractmethod @abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment.""" """Return an action to be taken in the environment."""

View File

@@ -43,16 +43,16 @@ _LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]] WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward(BaseModel): class AbstractReward(BaseModel, ABC):
"""Base class for reward function components.""" """Base class for reward function components."""
config: "AbstractReward.ConfigSchema"
class ConfigSchema(BaseModel, ABC): class ConfigSchema(BaseModel, ABC):
"""Config schema for AbstractReward.""" """Config schema for AbstractReward."""
type: str = "" type: str = ""
config: ConfigSchema
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import random import random
from abc import abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from gymnasium.core import ObsType from gymnasium.core import ObsType
@@ -13,18 +13,18 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "AbstractTAPAgent" __all__ = "AbstractTAPAgent"
class AbstractTAPAgent(PeriodicAgent, discriminator="abstract-tap"): class AbstractTAPAgent(PeriodicAgent, ABC):
"""Base class for TAP agents to inherit from.""" """Base class for TAP agents to inherit from."""
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema()) config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
next_execution_timestep: int = 0 next_execution_timestep: int = 0
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema): class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema, ABC):
"""Schema for the `agent_settings` part of the agent config.""" """Schema for the `agent_settings` part of the agent config."""
possible_starting_nodes: List[str] = Field(default_factory=list) possible_starting_nodes: List[str] = Field(default_factory=list)
class ConfigSchema(PeriodicAgent.ConfigSchema): class ConfigSchema(PeriodicAgent.ConfigSchema, ABC):
"""Configuration schema for Abstract TAP agents.""" """Configuration schema for Abstract TAP agents."""
type: str = "abstract-tap" type: str = "abstract-tap"

View File

@@ -6,8 +6,6 @@ from pydantic import Field
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "DataManipulationAgent"
class DataManipulationAgent(PeriodicAgent, discriminator="red-database-corrupting-agent"): class DataManipulationAgent(PeriodicAgent, discriminator="red-database-corrupting-agent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack.""" """Agent that uses a DataManipulationBot to perform an SQL injection attack."""

View File

@@ -288,8 +288,8 @@ class PrimaiteGame:
if "folder_restore_duration" in defaults_config: if "folder_restore_duration" in defaults_config:
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"] new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
for user_cfg in node_cfg["users"]: for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True) user_manager.add_user(**user_cfg, bypass_can_perform_action=True)

View File

@@ -170,13 +170,13 @@ def arcd_uc2_network() -> Network:
client_1.power_on() client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DatabaseClient)
db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") db_client_1: DatabaseClient = client_1.software_manager.software.get("database-client")
db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14")) db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14"))
db_client_1.run() db_client_1.run()
web_browser_1 = client_1.software_manager.software.get("web-browser") web_browser_1 = client_1.software_manager.software.get("web-browser")
web_browser_1.target_url = "http://arcd.com/users/" web_browser_1.target_url = "http://arcd.com/users/"
client_1.software_manager.install(DataManipulationBot) client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot")
db_manipulation_bot.configure( db_manipulation_bot.configure(
server_ip_address=IPv4Address("192.168.1.14"), server_ip_address=IPv4Address("192.168.1.14"),
payload="DELETE", payload="DELETE",

View File

@@ -37,7 +37,7 @@ class DatabaseClientConnection(BaseModel):
@property @property
def client(self) -> Optional[DatabaseClient]: def client(self) -> Optional[DatabaseClient]:
"""The DatabaseClient that holds this connection.""" """The DatabaseClient that holds this connection."""
return self.parent_node.software_manager.software.get("DatabaseClient") return self.parent_node.software_manager.software.get("database-client")
def query(self, sql: str) -> bool: def query(self, sql: str) -> bool:
""" """

View File

@@ -70,7 +70,7 @@ class NMAP(Application, discriminator="nmap"):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs["name"] = "NMAP" kwargs["name"] = "nmap"
kwargs["port"] = PORT_LOOKUP["NONE"] kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"]
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -678,8 +678,8 @@ def test_firewall_acl_add_remove_rule_integration():
assert firewall.external_outbound_acl.acl[1].action.name == "DENY" assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"] assert firewall.external_outbound_acl.acl[1].dst_port is None
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"] assert firewall.external_outbound_acl.acl[1].src_port is None
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"] assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
env.step(12) # Remove ACL rule from External Outbound env.step(12) # Remove ACL rule from External Outbound

View File

@@ -27,7 +27,7 @@ class BroadcastTestService(Service, discriminator="broadcast-test-service"):
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Set default service properties for broadcasting # Set default service properties for broadcasting
kwargs["name"] = "BroadcastService" kwargs["name"] = "broadcast-test-service"
kwargs["port"] = PORT_LOOKUP["HTTP"] kwargs["port"] = PORT_LOOKUP["HTTP"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -127,7 +127,7 @@ def broadcast_network() -> Network:
server_1.power_on() server_1.power_on()
server_1.software_manager.install(BroadcastTestService) server_1.software_manager.install(BroadcastTestService)
service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service: BroadcastTestService = server_1.software_manager.software["broadcast-test-service"]
service.start() service.start()
switch_1: Switch = Switch.from_config( switch_1: Switch = Switch.from_config(
@@ -153,7 +153,7 @@ def broadcast_service_and_clients(
"broadcast-test-client" "broadcast-test-client"
] ]
service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"broadcast-service" "broadcast-test-service"
] ]
return service, client_1, client_2 return service, client_1, client_2

View File

@@ -22,7 +22,7 @@ class _DatabaseListener(Service, discriminator="database-listener"):
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema()) config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
name: str = "DatabaseListener" name: str = "database-listener"
protocol: str = PROTOCOL_LOOKUP["TCP"] protocol: str = PROTOCOL_LOOKUP["TCP"]
port: int = PORT_LOOKUP["NONE"] port: int = PORT_LOOKUP["NONE"]
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
@@ -45,7 +45,7 @@ def test_http_listener(client_server):
server_db.start() server_db.start()
server.software_manager.install(_DatabaseListener) server.software_manager.install(_DatabaseListener)
server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"] server_db_listener: _DatabaseListener = server.software_manager.software["database-listener"]
server_db_listener.start() server_db_listener.start()
computer.software_manager.install(DatabaseClient) computer.software_manager.install(DatabaseClient)

View File

@@ -46,7 +46,7 @@ def test_create_web_client():
computer.power_on() computer.power_on()
# Web Browser should be pre-installed in computer # Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
assert web_browser.name is "web-browser" assert web_browser.name == "web-browser"
assert web_browser.port is PORT_LOOKUP["HTTP"] assert web_browser.port is PORT_LOOKUP["HTTP"]
assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"] assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -29,7 +29,7 @@ def dns_client() -> Computer:
def test_create_dns_client(dns_client): def test_create_dns_client(dns_client):
assert dns_client is not None assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client")
assert dns_client_service.name is "dns-client" assert dns_client_service.name == "dns-client"
assert dns_client_service.port is PORT_LOOKUP["DNS"] assert dns_client_service.port is PORT_LOOKUP["DNS"]
assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"] assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -33,7 +33,7 @@ def dns_server() -> Node:
def test_create_dns_server(dns_server): def test_create_dns_server(dns_server):
assert dns_server is not None assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server") dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server")
assert dns_server_service.name is "dns-server" assert dns_server_service.name == "dns-server"
assert dns_server_service.port is PORT_LOOKUP["DNS"] assert dns_server_service.port is PORT_LOOKUP["DNS"]
assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"] assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -32,7 +32,7 @@ def ftp_client() -> Node:
def test_create_ftp_client(ftp_client): def test_create_ftp_client(ftp_client):
assert ftp_client is not None assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client")
assert ftp_client_service.name is "ftp-client" assert ftp_client_service.name == "ftp-client"
assert ftp_client_service.port is PORT_LOOKUP["FTP"] assert ftp_client_service.port is PORT_LOOKUP["FTP"]
assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"] assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -31,7 +31,7 @@ def ftp_server() -> Node:
def test_create_ftp_server(ftp_server): def test_create_ftp_server(ftp_server):
assert ftp_server is not None assert ftp_server is not None
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server")
assert ftp_server_service.name is "ftp-server" assert ftp_server_service.name == "ftp-server"
assert ftp_server_service.port is PORT_LOOKUP["FTP"] assert ftp_server_service.port is PORT_LOOKUP["FTP"]
assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"] assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -364,7 +364,7 @@ def test_SSH_across_network():
pc_a = network.get_node_by_hostname("client_1") pc_a = network.get_node_by_hostname("client_1")
router_1 = network.get_node_by_hostname("router_1") router_1 = network.get_node_by_hostname("router_1")
terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_a: Terminal = pc_a.software_manager.software.get("terminal")
router_1.acl.add_rule( router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21

View File

@@ -34,9 +34,9 @@ def web_server() -> Server:
def test_create_web_server(web_server): def test_create_web_server(web_server):
assert web_server is not None assert web_server is not None
web_server_service: WebServer = web_server.software_manager.software.get("web-server") web_server_service: WebServer = web_server.software_manager.software.get("web-server")
assert web_server_service.name is "web-server" assert web_server_service.name == "web-server"
assert web_server_service.port is PORT_LOOKUP["HTTP"] assert web_server_service.port == PORT_LOOKUP["HTTP"]
assert web_server_service.protocol is PROTOCOL_LOOKUP["TCP"] assert web_server_service.protocol == PROTOCOL_LOOKUP["TCP"]
def test_handling_get_request_not_found_path(web_server): def test_handling_get_request_not_found_path(web_server):