Merge remote-tracking branch 'origin/feature/1812-traverse-actions-dict' into feature/1924-Agent-Interface

This commit is contained in:
Marek Wolan
2023-09-21 10:13:31 +01:00
16 changed files with 852 additions and 128 deletions

View File

@@ -6,7 +6,7 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
@@ -45,12 +45,12 @@ class Network(SimComponent):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
self._node_action_manager = ActionManager()
am.add_action(
"node",
Action(
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
validator=AllowAllValidator(),
func=self._node_action_manager
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
),
)
return am
@@ -184,7 +184,8 @@ class Network(SimComponent):
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -218,6 +219,7 @@ class Network(SimComponent):
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_action_manager.remove_action(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""

View File

@@ -12,7 +12,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -89,9 +89,9 @@ class NIC(SimComponent):
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the NIC is connected."
enabled: bool = False
"Indicates whether the NIC is enabled."
@@ -135,17 +135,23 @@ class NIC(SimComponent):
{
"ip_adress": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"gateway": str(self.gateway),
"mac_address": self.mac_address,
"speed": self.speed,
"mtu": self.mtu,
"wake_on_lan": self.wake_on_lan,
"dns_servers": self.dns_servers,
"enabled": self.enabled,
}
)
return state
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("enable", Action(func=lambda request, context: self.enable()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
return am
@property
def ip_network(self) -> IPv4Network:
"""
@@ -159,21 +165,21 @@ class NIC(SimComponent):
"""Attempt to enable the NIC."""
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
if not self.connected_link:
if not self._connected_link:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
return
self.enabled = True
self.connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, ip_address=self.ip_address)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the NIC."""
@@ -181,12 +187,12 @@ class NIC(SimComponent):
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.debug(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -195,26 +201,26 @@ class NIC(SimComponent):
:param link: The link to which the NIC is connected.
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
return
# TODO: Inform the Node that a link has been connected
self.connected_link = link
self._connected_link = link
self.enable()
_LOGGER.debug(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def add_dns_server(self, ip_address: IPv4Address):
"""
@@ -244,7 +250,7 @@ class NIC(SimComponent):
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the NIC is not enabled
return False
@@ -263,7 +269,7 @@ class NIC(SimComponent):
self.pcap.capture(frame)
# If this destination or is broadcast
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
self.connected_node.receive_frame(frame=frame, from_nic=self)
self._connected_node.receive_frame(frame=frame, from_nic=self)
return True
return False
@@ -288,9 +294,9 @@ class SwitchPort(SimComponent):
"The speed of the SwitchPort in Mbps. Default is 100 Mbps."
mtu: int = 1500
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the SwitchPort is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the SwitchPort is connected."
enabled: bool = False
"Indicates whether the SwitchPort is enabled."
@@ -327,31 +333,31 @@ class SwitchPort(SimComponent):
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"SwitchPort {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
return
self.enabled = True
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, switch_port_number=self.port_num)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, switch_port_number=self.port_num)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the SwitchPort."""
if not self.enabled:
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.debug(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -359,26 +365,26 @@ class SwitchPort(SimComponent):
:param link: The link to which the SwitchPort is connected.
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect link to SwitchPort {self.mac_address} as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
return
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
self._connected_link = link
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
"""Disconnect the SwitchPort from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def send_frame(self, frame: Frame) -> bool:
"""
@@ -388,7 +394,7 @@ class SwitchPort(SimComponent):
"""
if self.enabled:
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the SwitchPort is not enabled
return False
@@ -404,7 +410,7 @@ class SwitchPort(SimComponent):
if self.enabled:
frame.decrement_ttl()
self.pcap.capture(frame)
connected_node: Node = self.connected_node
connected_node: Node = self._connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
return True
return False
@@ -937,6 +943,34 @@ class Node(SimComponent):
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
def _init_action_manager(self) -> ActionManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
am = super()._init_action_manager()
# since there are potentially many services, create an action manager that can map service name
self._service_action_manager = ActionManager()
am.add_action("service", Action(func=self._service_action_manager))
self._nic_action_manager = ActionManager()
am.add_action("nic", Action(func=self._nic_action_manager))
am.add_action("file_system", Action(func=self.file_system._action_manager))
# currently we don't have any applications nor processes, so these will be empty
self._process_action_manager = ActionManager()
am.add_action("process", Action(func=self._process_action_manager))
self._application_action_manager = ActionManager()
am.add_action("application", Action(func=self._application_action_manager))
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement OS scan
am.add_action("shutdown", Action(func=lambda request, context: self.power_off()))
am.add_action("startup", Action(func=lambda request, context: self.power_on()))
am.add_action("reset", Action(func=lambda request, context: ...)) # TODO implement node reset
am.add_action("logon", Action(func=lambda request, context: ...)) # TODO implement logon action
am.add_action("logoff", Action(func=lambda request, context: ...)) # TODO implement logoff action
return am
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1004,7 +1038,7 @@ class Node(SimComponent):
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic.connected_link:
if nic._connected_link:
nic.enable()
def power_off(self):
@@ -1025,11 +1059,12 @@ class Node(SimComponent):
if nic.uuid not in self.nics:
self.nics[nic.uuid] = nic
self.ethernet_port[len(self.nics)] = nic
nic.connected_node = self
nic._connected_node = self
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
nic.enable()
self._nic_action_manager.add_action(nic.uuid, Action(func=nic._action_manager))
else:
msg = f"Cannot connect NIC {nic} as it is already connected"
self.sys_log.logger.error(msg)
@@ -1054,6 +1089,7 @@ class Node(SimComponent):
nic.parent = None
nic.disable()
self.sys_log.info(f"Disconnected NIC {nic}")
self._nic_action_manager.remove_action(nic.uuid)
else:
msg = f"Cannot disconnect NIC {nic} as it is not connected"
self.sys_log.logger.error(msg)
@@ -1150,7 +1186,8 @@ class Node(SimComponent):
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.uuid} to node {self.uuid}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
self._service_action_manager.add_action(service.uuid, Action(func=service._action_manager))
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
@@ -1165,7 +1202,8 @@ class Node(SimComponent):
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.debug(f"Removed service {service.uuid} from node {self.uuid}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
self._service_action_manager.remove_action(service.uuid)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
@@ -1188,7 +1226,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -1261,7 +1299,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
@@ -87,6 +87,36 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
# 0: action (str name of an ACLAction)
# 1: protocol (str name of an IPProtocol)
# 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443?
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
am.add_action(
"add_rule",
Action(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
IPv4Address[request[2]],
Port[request[3]],
IPv4Address[request[4]],
Port[request[5]],
int(request[6]),
)
),
)
am.add_action("remove_rule", Action(func=lambda request, context: self.remove_rule(int(request[0]))))
return am
def describe_state(self) -> Dict:
"""
Describes the current state of the AccessControlList.
@@ -596,6 +626,11 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("acl", Action(func=self.acl._action_manager))
return am
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
"""
Retrieve the port number for a given NIC.

View File

@@ -30,7 +30,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -113,7 +113,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)