Merge remote-tracking branch 'origin/feature/1812-traverse-actions-dict' into feature/1924-Agent-Interface

This commit is contained in:
Marek Wolan
2023-09-21 10:13:31 +01:00
16 changed files with 852 additions and 128 deletions

View File

@@ -1,6 +1,7 @@
# flake8: noqa
"""Core of the PrimAITE Simulator."""
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional
from typing import Callable, ClassVar, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel, ConfigDict
@@ -10,7 +11,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
class ActionPermissionValidator(ABC):
class ActionPermissionValidator(BaseModel):
"""
Base class for action validators.
@@ -33,7 +34,7 @@ class AllowAllValidator(ActionPermissionValidator):
return True
class Action:
class Action(BaseModel):
"""
This object stores data related to a single action.
@@ -41,34 +42,28 @@ class Action:
the action can be performed or not.
"""
def __init__(
self, func: Callable[[List[str], Dict], None], validator: ActionPermissionValidator = AllowAllValidator()
) -> None:
"""
Save the functions that are for this action.
func: Callable[[List[str], Dict], None]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
Here's a description for the intended use of both of these.
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
``validator`` is an instance of a subclass of `ActionPermissionValidator`. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the action.
:param func: Function that performs the request.
:type func: Callable[[List[str], Dict], None]
:param validator: Function that checks if the request is authenticated given the context. By default, if no
validator is provided, an 'allow all' validator is added which permits all requests.
:type validator: ActionPermissionValidator
"""
self.func: Callable[[List[str], Dict], None] = func
self.validator: ActionPermissionValidator = validator
``func`` can also be another action manager, since ActionManager is a callable with a signature that matches what is
expected by ``func``.
"""
validator: ActionPermissionValidator = AllowAllValidator()
"""
``validator`` is an instance of `ActionPermissionValidator`. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the action. The default validator will allow
"""
class ActionManager:
# TODO: maybe this can be renamed to something like action selector?
# Because there are two ways it's used, to select from a list of action verbs, or to select a child object to which to
# forward the request.
class ActionManager(BaseModel):
"""
ActionManager is used by `SimComponent` instances to keep track of actions.
@@ -76,12 +71,12 @@ class ActionManager:
class is responsible for providing a consistent API for processing actions as well as helpful error messages.
"""
def __init__(self) -> None:
"""Initialise ActionManager with an empty action lookup."""
self.actions: Dict[str, Action] = {}
actions: Dict[str, Action] = {}
"""maps action verb to an action object."""
def process_request(self, request: List[str], context: Dict) -> None:
"""Process an action request.
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
"""
Process an action request.
:param request: A list of strings which specify what action to take. The first string must be one of the allowed
actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters
@@ -111,7 +106,8 @@ class ActionManager:
action.func(action_options, context)
def add_action(self, name: str, action: Action) -> None:
"""Add an action to this action manager.
"""
Add an action to this action manager.
:param name: The string associated to this action.
:type name: str
@@ -125,6 +121,32 @@ class ActionManager:
self.actions[name] = action
def remove_action(self, name: str) -> None:
"""
Remove an action from this manager.
:param name: name identifier of the action
:type name: str
"""
if name not in self.actions:
msg = f"Attempted to remove action {name} from action manager, but it was not registered."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.actions.pop(name)
def get_action_tree(self) -> List[List[str]]:
"""Recursively generate action tree for this component."""
actions = []
for act_name, act in self.actions.items():
if isinstance(act.func, ActionManager):
sub_actions = act.func.get_action_tree()
sub_actions = [[act_name] + a for a in sub_actions]
actions.extend(sub_actions)
else:
actions.append([act_name])
return actions
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
@@ -140,7 +162,7 @@ class SimComponent(BaseModel):
kwargs["uuid"] = str(uuid4())
super().__init__(**kwargs)
self._action_manager: ActionManager = self._init_action_manager()
self.parent: Optional["SimComponent"] = None
self._parent: Optional["SimComponent"] = None
def _init_action_manager(self) -> ActionManager:
"""
@@ -196,9 +218,9 @@ class SimComponent(BaseModel):
:param: context: Dict containing context for actions
:type context: Dict
"""
if self.action_manager is None:
if self._action_manager is None:
return
self.action_manager.process_request(action, context)
self._action_manager(action, context)
def apply_timestep(self, timestep: int) -> None:
"""
@@ -216,3 +238,20 @@ class SimComponent(BaseModel):
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.
:return: Parent object.
:rtype: SimComponent
"""
return self._parent
@parent.setter
def parent(self, new_parent: Union["SimComponent", None]) -> None:
if self._parent and new_parent:
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
_LOGGER.warn(msg)
raise RuntimeWarning(msg)
self._parent = new_parent

View File

@@ -46,13 +46,7 @@ class AccountGroup(Enum):
class GroupMembershipValidator(ActionPermissionValidator):
"""Permit actions based on group membership."""
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
"""Store a list of groups that should be granted permission.
:param allowed_groups: List of AccountGroups that are permitted to perform some action.
:type allowed_groups: List[AccountGroup]
"""
self.allowed_groups = allowed_groups
allowed_groups:List[AccountGroup]
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
@@ -93,7 +87,7 @@ class DomainController(SimComponent):
"account",
Action(
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)
return am

View File

@@ -10,7 +10,7 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from primaite.simulator.system.core.sys_log import SysLog
@@ -100,6 +100,17 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
self._folder_action_manager = ActionManager()
am.add_action("folder", Action(func=self._folder_action_manager))
self._file_action_manager = ActionManager()
am.add_action("file", Action(func=self._file_action_manager))
return am
@property
def size(self) -> int:
"""
@@ -160,6 +171,7 @@ class FileSystem(SimComponent):
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
self._folder_action_manager.add_action(folder.uuid, Action(func=folder._action_manager))
return folder
def delete_folder(self, folder_name: str):
@@ -178,6 +190,7 @@ class FileSystem(SimComponent):
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
self._folder_action_manager.remove_action(folder.uuid)
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
@@ -219,6 +232,7 @@ class FileSystem(SimComponent):
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
self._file_action_manager.add_action(file.uuid, Action(func=file._action_manager))
return file
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
@@ -246,6 +260,7 @@ class FileSystem(SimComponent):
file = folder.get_file(file_name)
if file:
folder.remove_file(file)
self._file_action_manager.remove_action(file.uuid)
self.sys_log.info(f"Deleted file /{file.path}")
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
@@ -323,6 +338,18 @@ class Folder(FileSystemItemABC):
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
def _init_action_manager(sekf) -> ActionManager:
am = super()._init_action_manager()
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
return am
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -488,6 +515,18 @@ class File(FileSystemItemABC):
with open(self.sim_path, mode="a"):
pass
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
return am
def make_copy(self, dst_folder: Folder) -> File:
"""
Create a copy of the current File object in the given destination folder.

View File

@@ -6,7 +6,7 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
@@ -45,12 +45,12 @@ class Network(SimComponent):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
self._node_action_manager = ActionManager()
am.add_action(
"node",
Action(
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
validator=AllowAllValidator(),
func=self._node_action_manager
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
),
)
return am
@@ -184,7 +184,8 @@ class Network(SimComponent):
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -218,6 +219,7 @@ class Network(SimComponent):
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_action_manager.remove_action(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""

View File

@@ -12,7 +12,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -89,9 +89,9 @@ class NIC(SimComponent):
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the NIC is connected."
enabled: bool = False
"Indicates whether the NIC is enabled."
@@ -135,17 +135,23 @@ class NIC(SimComponent):
{
"ip_adress": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"gateway": str(self.gateway),
"mac_address": self.mac_address,
"speed": self.speed,
"mtu": self.mtu,
"wake_on_lan": self.wake_on_lan,
"dns_servers": self.dns_servers,
"enabled": self.enabled,
}
)
return state
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("enable", Action(func=lambda request, context: self.enable()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
return am
@property
def ip_network(self) -> IPv4Network:
"""
@@ -159,21 +165,21 @@ class NIC(SimComponent):
"""Attempt to enable the NIC."""
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
if not self.connected_link:
if not self._connected_link:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
return
self.enabled = True
self.connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, ip_address=self.ip_address)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the NIC."""
@@ -181,12 +187,12 @@ class NIC(SimComponent):
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.debug(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -195,26 +201,26 @@ class NIC(SimComponent):
:param link: The link to which the NIC is connected.
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
return
# TODO: Inform the Node that a link has been connected
self.connected_link = link
self._connected_link = link
self.enable()
_LOGGER.debug(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def add_dns_server(self, ip_address: IPv4Address):
"""
@@ -244,7 +250,7 @@ class NIC(SimComponent):
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the NIC is not enabled
return False
@@ -263,7 +269,7 @@ class NIC(SimComponent):
self.pcap.capture(frame)
# If this destination or is broadcast
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
self.connected_node.receive_frame(frame=frame, from_nic=self)
self._connected_node.receive_frame(frame=frame, from_nic=self)
return True
return False
@@ -288,9 +294,9 @@ class SwitchPort(SimComponent):
"The speed of the SwitchPort in Mbps. Default is 100 Mbps."
mtu: int = 1500
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the SwitchPort is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the SwitchPort is connected."
enabled: bool = False
"Indicates whether the SwitchPort is enabled."
@@ -327,31 +333,31 @@ class SwitchPort(SimComponent):
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"SwitchPort {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
return
self.enabled = True
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, switch_port_number=self.port_num)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, switch_port_number=self.port_num)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the SwitchPort."""
if not self.enabled:
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.debug(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -359,26 +365,26 @@ class SwitchPort(SimComponent):
:param link: The link to which the SwitchPort is connected.
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect link to SwitchPort {self.mac_address} as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
return
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
self._connected_link = link
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
"""Disconnect the SwitchPort from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def send_frame(self, frame: Frame) -> bool:
"""
@@ -388,7 +394,7 @@ class SwitchPort(SimComponent):
"""
if self.enabled:
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the SwitchPort is not enabled
return False
@@ -404,7 +410,7 @@ class SwitchPort(SimComponent):
if self.enabled:
frame.decrement_ttl()
self.pcap.capture(frame)
connected_node: Node = self.connected_node
connected_node: Node = self._connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
return True
return False
@@ -937,6 +943,34 @@ class Node(SimComponent):
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
def _init_action_manager(self) -> ActionManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
am = super()._init_action_manager()
# since there are potentially many services, create an action manager that can map service name
self._service_action_manager = ActionManager()
am.add_action("service", Action(func=self._service_action_manager))
self._nic_action_manager = ActionManager()
am.add_action("nic", Action(func=self._nic_action_manager))
am.add_action("file_system", Action(func=self.file_system._action_manager))
# currently we don't have any applications nor processes, so these will be empty
self._process_action_manager = ActionManager()
am.add_action("process", Action(func=self._process_action_manager))
self._application_action_manager = ActionManager()
am.add_action("application", Action(func=self._application_action_manager))
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement OS scan
am.add_action("shutdown", Action(func=lambda request, context: self.power_off()))
am.add_action("startup", Action(func=lambda request, context: self.power_on()))
am.add_action("reset", Action(func=lambda request, context: ...)) # TODO implement node reset
am.add_action("logon", Action(func=lambda request, context: ...)) # TODO implement logon action
am.add_action("logoff", Action(func=lambda request, context: ...)) # TODO implement logoff action
return am
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1004,7 +1038,7 @@ class Node(SimComponent):
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic.connected_link:
if nic._connected_link:
nic.enable()
def power_off(self):
@@ -1025,11 +1059,12 @@ class Node(SimComponent):
if nic.uuid not in self.nics:
self.nics[nic.uuid] = nic
self.ethernet_port[len(self.nics)] = nic
nic.connected_node = self
nic._connected_node = self
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
nic.enable()
self._nic_action_manager.add_action(nic.uuid, Action(func=nic._action_manager))
else:
msg = f"Cannot connect NIC {nic} as it is already connected"
self.sys_log.logger.error(msg)
@@ -1054,6 +1089,7 @@ class Node(SimComponent):
nic.parent = None
nic.disable()
self.sys_log.info(f"Disconnected NIC {nic}")
self._nic_action_manager.remove_action(nic.uuid)
else:
msg = f"Cannot disconnect NIC {nic} as it is not connected"
self.sys_log.logger.error(msg)
@@ -1150,7 +1186,8 @@ class Node(SimComponent):
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.uuid} to node {self.uuid}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
self._service_action_manager.add_action(service.uuid, Action(func=service._action_manager))
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
@@ -1165,7 +1202,8 @@ class Node(SimComponent):
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.debug(f"Removed service {service.uuid} from node {self.uuid}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
self._service_action_manager.remove_action(service.uuid)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
@@ -1188,7 +1226,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -1261,7 +1299,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
@@ -87,6 +87,36 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
# 0: action (str name of an ACLAction)
# 1: protocol (str name of an IPProtocol)
# 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443?
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
am.add_action(
"add_rule",
Action(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
IPv4Address[request[2]],
Port[request[3]],
IPv4Address[request[4]],
Port[request[5]],
int(request[6]),
)
),
)
am.add_action("remove_rule", Action(func=lambda request, context: self.remove_rule(int(request[0]))))
return am
def describe_state(self) -> Dict:
"""
Describes the current state of the AccessControlList.
@@ -596,6 +626,11 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("acl", Action(func=self.acl._action_manager))
return am
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
"""
Retrieve the port number for a given NIC.

View File

@@ -30,7 +30,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -113,7 +113,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)

View File

@@ -24,19 +24,9 @@ class Simulation(SimComponent):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# pass through network actions to the network objects
am.add_action(
"network",
Action(
func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator()
),
)
am.add_action("network", Action(func=self.network._action_manager))
# pass through domain actions to the domain object
am.add_action(
"domain",
Action(
func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator()
),
)
am.add_action("domain", Action(func=self.domain._action_manager))
return am
def describe_state(self) -> Dict:

View File

@@ -204,7 +204,7 @@ class IOSoftware(Software):
"max_sessions": self.max_sessions,
"tcp": self.tcp,
"udp": self.udp,
"ports": [port.name for port in self.ports], # TODO: not sure if this should be port.name or port.value
"port": self.port.value,
}
)
return state