Merge remote-tracking branch 'origin/feature/1812-traverse-actions-dict' into feature/1924-Agent-Interface
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# flake8: noqa
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -10,7 +11,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ActionPermissionValidator(ABC):
|
||||
class ActionPermissionValidator(BaseModel):
|
||||
"""
|
||||
Base class for action validators.
|
||||
|
||||
@@ -33,7 +34,7 @@ class AllowAllValidator(ActionPermissionValidator):
|
||||
return True
|
||||
|
||||
|
||||
class Action:
|
||||
class Action(BaseModel):
|
||||
"""
|
||||
This object stores data related to a single action.
|
||||
|
||||
@@ -41,34 +42,28 @@ class Action:
|
||||
the action can be performed or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, func: Callable[[List[str], Dict], None], validator: ActionPermissionValidator = AllowAllValidator()
|
||||
) -> None:
|
||||
"""
|
||||
Save the functions that are for this action.
|
||||
func: Callable[[List[str], Dict], None]
|
||||
"""
|
||||
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
|
||||
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
|
||||
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
|
||||
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
|
||||
|
||||
Here's a description for the intended use of both of these.
|
||||
|
||||
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
|
||||
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
|
||||
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
|
||||
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
|
||||
|
||||
``validator`` is an instance of a subclass of `ActionPermissionValidator`. This is essentially a callable that
|
||||
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
|
||||
the action.
|
||||
|
||||
:param func: Function that performs the request.
|
||||
:type func: Callable[[List[str], Dict], None]
|
||||
:param validator: Function that checks if the request is authenticated given the context. By default, if no
|
||||
validator is provided, an 'allow all' validator is added which permits all requests.
|
||||
:type validator: ActionPermissionValidator
|
||||
"""
|
||||
self.func: Callable[[List[str], Dict], None] = func
|
||||
self.validator: ActionPermissionValidator = validator
|
||||
``func`` can also be another action manager, since ActionManager is a callable with a signature that matches what is
|
||||
expected by ``func``.
|
||||
"""
|
||||
validator: ActionPermissionValidator = AllowAllValidator()
|
||||
"""
|
||||
``validator`` is an instance of `ActionPermissionValidator`. This is essentially a callable that
|
||||
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
|
||||
the action. The default validator will allow
|
||||
"""
|
||||
|
||||
|
||||
class ActionManager:
|
||||
# TODO: maybe this can be renamed to something like action selector?
|
||||
# Because there are two ways it's used, to select from a list of action verbs, or to select a child object to which to
|
||||
# forward the request.
|
||||
class ActionManager(BaseModel):
|
||||
"""
|
||||
ActionManager is used by `SimComponent` instances to keep track of actions.
|
||||
|
||||
@@ -76,12 +71,12 @@ class ActionManager:
|
||||
class is responsible for providing a consistent API for processing actions as well as helpful error messages.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise ActionManager with an empty action lookup."""
|
||||
self.actions: Dict[str, Action] = {}
|
||||
actions: Dict[str, Action] = {}
|
||||
"""maps action verb to an action object."""
|
||||
|
||||
def process_request(self, request: List[str], context: Dict) -> None:
|
||||
"""Process an action request.
|
||||
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
|
||||
"""
|
||||
Process an action request.
|
||||
|
||||
:param request: A list of strings which specify what action to take. The first string must be one of the allowed
|
||||
actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters
|
||||
@@ -111,7 +106,8 @@ class ActionManager:
|
||||
action.func(action_options, context)
|
||||
|
||||
def add_action(self, name: str, action: Action) -> None:
|
||||
"""Add an action to this action manager.
|
||||
"""
|
||||
Add an action to this action manager.
|
||||
|
||||
:param name: The string associated to this action.
|
||||
:type name: str
|
||||
@@ -125,6 +121,32 @@ class ActionManager:
|
||||
|
||||
self.actions[name] = action
|
||||
|
||||
def remove_action(self, name: str) -> None:
|
||||
"""
|
||||
Remove an action from this manager.
|
||||
|
||||
:param name: name identifier of the action
|
||||
:type name: str
|
||||
"""
|
||||
if name not in self.actions:
|
||||
msg = f"Attempted to remove action {name} from action manager, but it was not registered."
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self.actions.pop(name)
|
||||
|
||||
def get_action_tree(self) -> List[List[str]]:
|
||||
"""Recursively generate action tree for this component."""
|
||||
actions = []
|
||||
for act_name, act in self.actions.items():
|
||||
if isinstance(act.func, ActionManager):
|
||||
sub_actions = act.func.get_action_tree()
|
||||
sub_actions = [[act_name] + a for a in sub_actions]
|
||||
actions.extend(sub_actions)
|
||||
else:
|
||||
actions.append([act_name])
|
||||
return actions
|
||||
|
||||
|
||||
class SimComponent(BaseModel):
|
||||
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
|
||||
@@ -140,7 +162,7 @@ class SimComponent(BaseModel):
|
||||
kwargs["uuid"] = str(uuid4())
|
||||
super().__init__(**kwargs)
|
||||
self._action_manager: ActionManager = self._init_action_manager()
|
||||
self.parent: Optional["SimComponent"] = None
|
||||
self._parent: Optional["SimComponent"] = None
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
"""
|
||||
@@ -196,9 +218,9 @@ class SimComponent(BaseModel):
|
||||
:param: context: Dict containing context for actions
|
||||
:type context: Dict
|
||||
"""
|
||||
if self.action_manager is None:
|
||||
if self._action_manager is None:
|
||||
return
|
||||
self.action_manager.process_request(action, context)
|
||||
self._action_manager(action, context)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
@@ -216,3 +238,20 @@ class SimComponent(BaseModel):
|
||||
Override this method with anything that needs to happen within the component for it to be reset.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def parent(self) -> "SimComponent":
|
||||
"""Reference to the parent object which manages this object.
|
||||
|
||||
:return: Parent object.
|
||||
:rtype: SimComponent
|
||||
"""
|
||||
return self._parent
|
||||
|
||||
@parent.setter
|
||||
def parent(self, new_parent: Union["SimComponent", None]) -> None:
|
||||
if self._parent and new_parent:
|
||||
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
|
||||
_LOGGER.warn(msg)
|
||||
raise RuntimeWarning(msg)
|
||||
self._parent = new_parent
|
||||
|
||||
@@ -46,13 +46,7 @@ class AccountGroup(Enum):
|
||||
class GroupMembershipValidator(ActionPermissionValidator):
|
||||
"""Permit actions based on group membership."""
|
||||
|
||||
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
|
||||
"""Store a list of groups that should be granted permission.
|
||||
|
||||
:param allowed_groups: List of AccountGroups that are permitted to perform some action.
|
||||
:type allowed_groups: List[AccountGroup]
|
||||
"""
|
||||
self.allowed_groups = allowed_groups
|
||||
allowed_groups:List[AccountGroup]
|
||||
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
"""Permit the action if the request comes from an account which belongs to the right group."""
|
||||
@@ -93,7 +87,7 @@ class DomainController(SimComponent):
|
||||
"account",
|
||||
Action(
|
||||
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
|
||||
validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
|
||||
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
|
||||
),
|
||||
)
|
||||
return am
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import SimComponent
|
||||
from primaite.simulator.core import Action, ActionManager, SimComponent
|
||||
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
@@ -100,6 +100,17 @@ class FileSystem(SimComponent):
|
||||
if not self.folders:
|
||||
self.create_folder("root")
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
self._folder_action_manager = ActionManager()
|
||||
am.add_action("folder", Action(func=self._folder_action_manager))
|
||||
|
||||
self._file_action_manager = ActionManager()
|
||||
am.add_action("file", Action(func=self._file_action_manager))
|
||||
|
||||
return am
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""
|
||||
@@ -160,6 +171,7 @@ class FileSystem(SimComponent):
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
self.sys_log.info(f"Created folder /{folder.name}")
|
||||
self._folder_action_manager.add_action(folder.uuid, Action(func=folder._action_manager))
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str):
|
||||
@@ -178,6 +190,7 @@ class FileSystem(SimComponent):
|
||||
self.folders.pop(folder.uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
self._folder_action_manager.remove_action(folder.uuid)
|
||||
else:
|
||||
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
|
||||
@@ -219,6 +232,7 @@ class FileSystem(SimComponent):
|
||||
)
|
||||
folder.add_file(file)
|
||||
self.sys_log.info(f"Created file /{file.path}")
|
||||
self._file_action_manager.add_action(file.uuid, Action(func=file._action_manager))
|
||||
return file
|
||||
|
||||
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
|
||||
@@ -246,6 +260,7 @@ class FileSystem(SimComponent):
|
||||
file = folder.get_file(file_name)
|
||||
if file:
|
||||
folder.remove_file(file)
|
||||
self._file_action_manager.remove_action(file.uuid)
|
||||
self.sys_log.info(f"Deleted file /{file.path}")
|
||||
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
|
||||
@@ -323,6 +338,18 @@ class Folder(FileSystemItemABC):
|
||||
is_quarantined: bool = False
|
||||
"Flag that marks the folder as quarantined if true."
|
||||
|
||||
def _init_action_manager(sekf) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
|
||||
return am
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -488,6 +515,18 @@ class File(FileSystemItemABC):
|
||||
with open(self.sim_path, mode="a"):
|
||||
pass
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
|
||||
|
||||
return am
|
||||
|
||||
def make_copy(self, dst_folder: Folder) -> File:
|
||||
"""
|
||||
Create a copy of the current File object in the given destination folder.
|
||||
|
||||
@@ -6,7 +6,7 @@ from networkx import MultiGraph
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
|
||||
from primaite.simulator.core import Action, ActionManager, SimComponent
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import Router
|
||||
@@ -45,12 +45,12 @@ class Network(SimComponent):
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
self._node_action_manager = ActionManager()
|
||||
am.add_action(
|
||||
"node",
|
||||
Action(
|
||||
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
|
||||
validator=AllowAllValidator(),
|
||||
func=self._node_action_manager
|
||||
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
|
||||
),
|
||||
)
|
||||
return am
|
||||
@@ -184,7 +184,8 @@ class Network(SimComponent):
|
||||
self._node_id_map[len(self.nodes)] = node
|
||||
node.parent = self
|
||||
self._nx_graph.add_node(node.hostname)
|
||||
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
|
||||
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
|
||||
self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager))
|
||||
|
||||
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
|
||||
"""
|
||||
@@ -218,6 +219,7 @@ class Network(SimComponent):
|
||||
break
|
||||
node.parent = None
|
||||
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
|
||||
self._node_action_manager.remove_action(name=node.uuid)
|
||||
|
||||
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import SimComponent
|
||||
from primaite.simulator.core import Action, ActionManager, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
@@ -89,9 +89,9 @@ class NIC(SimComponent):
|
||||
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
|
||||
wake_on_lan: bool = False
|
||||
"Indicates if the NIC supports Wake-on-LAN functionality."
|
||||
connected_node: Optional[Node] = None
|
||||
_connected_node: Optional[Node] = None
|
||||
"The Node to which the NIC is connected."
|
||||
connected_link: Optional[Link] = None
|
||||
_connected_link: Optional[Link] = None
|
||||
"The Link to which the NIC is connected."
|
||||
enabled: bool = False
|
||||
"Indicates whether the NIC is enabled."
|
||||
@@ -135,17 +135,23 @@ class NIC(SimComponent):
|
||||
{
|
||||
"ip_adress": str(self.ip_address),
|
||||
"subnet_mask": str(self.subnet_mask),
|
||||
"gateway": str(self.gateway),
|
||||
"mac_address": self.mac_address,
|
||||
"speed": self.speed,
|
||||
"mtu": self.mtu,
|
||||
"wake_on_lan": self.wake_on_lan,
|
||||
"dns_servers": self.dns_servers,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
am.add_action("enable", Action(func=lambda request, context: self.enable()))
|
||||
am.add_action("disable", Action(func=lambda request, context: self.disable()))
|
||||
|
||||
return am
|
||||
|
||||
@property
|
||||
def ip_network(self) -> IPv4Network:
|
||||
"""
|
||||
@@ -159,21 +165,21 @@ class NIC(SimComponent):
|
||||
"""Attempt to enable the NIC."""
|
||||
if self.enabled:
|
||||
return
|
||||
if not self.connected_node:
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
if self.connected_node.operating_state != NodeOperatingState.ON:
|
||||
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
|
||||
return
|
||||
if not self.connected_link:
|
||||
if not self._connected_link:
|
||||
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
|
||||
return
|
||||
|
||||
self.enabled = True
|
||||
self.connected_node.sys_log.info(f"NIC {self} enabled")
|
||||
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
|
||||
if self.connected_link:
|
||||
self.connected_link.endpoint_up()
|
||||
self._connected_node.sys_log.info(f"NIC {self} enabled")
|
||||
self.pcap = PacketCapture(hostname=self._connected_node.hostname, ip_address=self.ip_address)
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_up()
|
||||
|
||||
def disable(self):
|
||||
"""Disable the NIC."""
|
||||
@@ -181,12 +187,12 @@ class NIC(SimComponent):
|
||||
return
|
||||
|
||||
self.enabled = False
|
||||
if self.connected_node:
|
||||
self.connected_node.sys_log.info(f"NIC {self} disabled")
|
||||
if self._connected_node:
|
||||
self._connected_node.sys_log.info(f"NIC {self} disabled")
|
||||
else:
|
||||
_LOGGER.debug(f"NIC {self} disabled")
|
||||
if self.connected_link:
|
||||
self.connected_link.endpoint_down()
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_down()
|
||||
|
||||
def connect_link(self, link: Link):
|
||||
"""
|
||||
@@ -195,26 +201,26 @@ class NIC(SimComponent):
|
||||
:param link: The link to which the NIC is connected.
|
||||
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
|
||||
"""
|
||||
if self.connected_link:
|
||||
if self._connected_link:
|
||||
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it already has a connection")
|
||||
return
|
||||
|
||||
if self.connected_link == link:
|
||||
if self._connected_link == link:
|
||||
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
|
||||
return
|
||||
|
||||
# TODO: Inform the Node that a link has been connected
|
||||
self.connected_link = link
|
||||
self._connected_link = link
|
||||
self.enable()
|
||||
_LOGGER.debug(f"NIC {self} connected to Link {link}")
|
||||
|
||||
def disconnect_link(self):
|
||||
"""Disconnect the NIC from the connected Link."""
|
||||
if self.connected_link.endpoint_a == self:
|
||||
self.connected_link.endpoint_a = None
|
||||
if self.connected_link.endpoint_b == self:
|
||||
self.connected_link.endpoint_b = None
|
||||
self.connected_link = None
|
||||
if self._connected_link.endpoint_a == self:
|
||||
self._connected_link.endpoint_a = None
|
||||
if self._connected_link.endpoint_b == self:
|
||||
self._connected_link.endpoint_b = None
|
||||
self._connected_link = None
|
||||
|
||||
def add_dns_server(self, ip_address: IPv4Address):
|
||||
"""
|
||||
@@ -244,7 +250,7 @@ class NIC(SimComponent):
|
||||
if self.enabled:
|
||||
frame.set_sent_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
|
||||
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
|
||||
return True
|
||||
# Cannot send Frame as the NIC is not enabled
|
||||
return False
|
||||
@@ -263,7 +269,7 @@ class NIC(SimComponent):
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
self.connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -288,9 +294,9 @@ class SwitchPort(SimComponent):
|
||||
"The speed of the SwitchPort in Mbps. Default is 100 Mbps."
|
||||
mtu: int = 1500
|
||||
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
|
||||
connected_node: Optional[Node] = None
|
||||
_connected_node: Optional[Node] = None
|
||||
"The Node to which the SwitchPort is connected."
|
||||
connected_link: Optional[Link] = None
|
||||
_connected_link: Optional[Link] = None
|
||||
"The Link to which the SwitchPort is connected."
|
||||
enabled: bool = False
|
||||
"Indicates whether the SwitchPort is enabled."
|
||||
@@ -327,31 +333,31 @@ class SwitchPort(SimComponent):
|
||||
if self.enabled:
|
||||
return
|
||||
|
||||
if not self.connected_node:
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"SwitchPort {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
|
||||
if self.connected_node.operating_state != NodeOperatingState.ON:
|
||||
self.connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
|
||||
return
|
||||
|
||||
self.enabled = True
|
||||
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
|
||||
self.pcap = PacketCapture(hostname=self.connected_node.hostname, switch_port_number=self.port_num)
|
||||
if self.connected_link:
|
||||
self.connected_link.endpoint_up()
|
||||
self._connected_node.sys_log.info(f"SwitchPort {self} enabled")
|
||||
self.pcap = PacketCapture(hostname=self._connected_node.hostname, switch_port_number=self.port_num)
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_up()
|
||||
|
||||
def disable(self):
|
||||
"""Disable the SwitchPort."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.enabled = False
|
||||
if self.connected_node:
|
||||
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
|
||||
if self._connected_node:
|
||||
self._connected_node.sys_log.info(f"SwitchPort {self} disabled")
|
||||
else:
|
||||
_LOGGER.debug(f"SwitchPort {self} disabled")
|
||||
if self.connected_link:
|
||||
self.connected_link.endpoint_down()
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_down()
|
||||
|
||||
def connect_link(self, link: Link):
|
||||
"""
|
||||
@@ -359,26 +365,26 @@ class SwitchPort(SimComponent):
|
||||
|
||||
:param link: The link to which the SwitchPort is connected.
|
||||
"""
|
||||
if self.connected_link:
|
||||
if self._connected_link:
|
||||
_LOGGER.error(f"Cannot connect link to SwitchPort {self.mac_address} as it already has a connection")
|
||||
return
|
||||
|
||||
if self.connected_link == link:
|
||||
if self._connected_link == link:
|
||||
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
|
||||
return
|
||||
|
||||
# TODO: Inform the Switch that a link has been connected
|
||||
self.connected_link = link
|
||||
self._connected_link = link
|
||||
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
|
||||
self.enable()
|
||||
|
||||
def disconnect_link(self):
|
||||
"""Disconnect the SwitchPort from the connected Link."""
|
||||
if self.connected_link.endpoint_a == self:
|
||||
self.connected_link.endpoint_a = None
|
||||
if self.connected_link.endpoint_b == self:
|
||||
self.connected_link.endpoint_b = None
|
||||
self.connected_link = None
|
||||
if self._connected_link.endpoint_a == self:
|
||||
self._connected_link.endpoint_a = None
|
||||
if self._connected_link.endpoint_b == self:
|
||||
self._connected_link.endpoint_b = None
|
||||
self._connected_link = None
|
||||
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
@@ -388,7 +394,7 @@ class SwitchPort(SimComponent):
|
||||
"""
|
||||
if self.enabled:
|
||||
self.pcap.capture(frame)
|
||||
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
|
||||
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
|
||||
return True
|
||||
# Cannot send Frame as the SwitchPort is not enabled
|
||||
return False
|
||||
@@ -404,7 +410,7 @@ class SwitchPort(SimComponent):
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
self.pcap.capture(frame)
|
||||
connected_node: Node = self.connected_node
|
||||
connected_node: Node = self._connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
return True
|
||||
return False
|
||||
@@ -937,6 +943,34 @@ class Node(SimComponent):
|
||||
self.arp.nics = self.nics
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
|
||||
# need a better name and better documentation.
|
||||
am = super()._init_action_manager()
|
||||
# since there are potentially many services, create an action manager that can map service name
|
||||
self._service_action_manager = ActionManager()
|
||||
am.add_action("service", Action(func=self._service_action_manager))
|
||||
self._nic_action_manager = ActionManager()
|
||||
am.add_action("nic", Action(func=self._nic_action_manager))
|
||||
|
||||
am.add_action("file_system", Action(func=self.file_system._action_manager))
|
||||
|
||||
# currently we don't have any applications nor processes, so these will be empty
|
||||
self._process_action_manager = ActionManager()
|
||||
am.add_action("process", Action(func=self._process_action_manager))
|
||||
self._application_action_manager = ActionManager()
|
||||
am.add_action("application", Action(func=self._application_action_manager))
|
||||
|
||||
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement OS scan
|
||||
|
||||
am.add_action("shutdown", Action(func=lambda request, context: self.power_off()))
|
||||
am.add_action("startup", Action(func=lambda request, context: self.power_on()))
|
||||
am.add_action("reset", Action(func=lambda request, context: ...)) # TODO implement node reset
|
||||
am.add_action("logon", Action(func=lambda request, context: ...)) # TODO implement logon action
|
||||
am.add_action("logoff", Action(func=lambda request, context: ...)) # TODO implement logoff action
|
||||
|
||||
return am
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -1004,7 +1038,7 @@ class Node(SimComponent):
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self.sys_log.info("Turned on")
|
||||
for nic in self.nics.values():
|
||||
if nic.connected_link:
|
||||
if nic._connected_link:
|
||||
nic.enable()
|
||||
|
||||
def power_off(self):
|
||||
@@ -1025,11 +1059,12 @@ class Node(SimComponent):
|
||||
if nic.uuid not in self.nics:
|
||||
self.nics[nic.uuid] = nic
|
||||
self.ethernet_port[len(self.nics)] = nic
|
||||
nic.connected_node = self
|
||||
nic._connected_node = self
|
||||
nic.parent = self
|
||||
self.sys_log.info(f"Connected NIC {nic}")
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
nic.enable()
|
||||
self._nic_action_manager.add_action(nic.uuid, Action(func=nic._action_manager))
|
||||
else:
|
||||
msg = f"Cannot connect NIC {nic} as it is already connected"
|
||||
self.sys_log.logger.error(msg)
|
||||
@@ -1054,6 +1089,7 @@ class Node(SimComponent):
|
||||
nic.parent = None
|
||||
nic.disable()
|
||||
self.sys_log.info(f"Disconnected NIC {nic}")
|
||||
self._nic_action_manager.remove_action(nic.uuid)
|
||||
else:
|
||||
msg = f"Cannot disconnect NIC {nic} as it is not connected"
|
||||
self.sys_log.logger.error(msg)
|
||||
@@ -1150,7 +1186,8 @@ class Node(SimComponent):
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.debug(f"Added service {service.uuid} to node {self.uuid}")
|
||||
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
|
||||
self._service_action_manager.add_action(service.uuid, Action(func=service._action_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
"""Uninstall and completely remove service from this node.
|
||||
@@ -1165,7 +1202,8 @@ class Node(SimComponent):
|
||||
self.services.pop(service.uuid)
|
||||
service.parent = None
|
||||
self.sys_log.info(f"Uninstalled service {service.name}")
|
||||
_LOGGER.debug(f"Removed service {service.uuid} from node {self.uuid}")
|
||||
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
|
||||
self._service_action_manager.remove_action(service.uuid)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
if isinstance(item, Service):
|
||||
@@ -1188,7 +1226,7 @@ class Switch(Node):
|
||||
if not self.switch_ports:
|
||||
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.switch_ports.items():
|
||||
port.connected_node = self
|
||||
port._connected_node = self
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
|
||||
@@ -1261,7 +1299,7 @@ class Switch(Node):
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
if port.connected_link != link:
|
||||
if port._connected_link != link:
|
||||
msg = f"The link does not match the connection at port number {port_number}"
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import SimComponent
|
||||
from primaite.simulator.core import Action, ActionManager, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
@@ -87,6 +87,36 @@ class AccessControlList(SimComponent):
|
||||
super().__init__(**kwargs)
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
|
||||
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
|
||||
# POSITIONAL ARGUMENTS:
|
||||
# 0: action (str name of an ACLAction)
|
||||
# 1: protocol (str name of an IPProtocol)
|
||||
# 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
|
||||
# 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443?
|
||||
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
|
||||
# 5: destination port (str name of a Port (e.g. "HTTP"))
|
||||
# 6: position (int)
|
||||
am.add_action(
|
||||
"add_rule",
|
||||
Action(
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
IPProtocol[request[1]],
|
||||
IPv4Address[request[2]],
|
||||
Port[request[3]],
|
||||
IPv4Address[request[4]],
|
||||
Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
am.add_action("remove_rule", Action(func=lambda request, context: self.remove_rule(int(request[0]))))
|
||||
return am
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the AccessControlList.
|
||||
@@ -596,6 +626,11 @@ class Router(Node):
|
||||
self.arp.nics = self.nics
|
||||
self.icmp.arp = self.arp
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
am.add_action("acl", Action(func=self.acl._action_manager))
|
||||
return am
|
||||
|
||||
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
|
||||
"""
|
||||
Retrieve the port number for a given NIC.
|
||||
|
||||
@@ -30,7 +30,7 @@ class Switch(Node):
|
||||
if not self.switch_ports:
|
||||
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.switch_ports.items():
|
||||
port.connected_node = self
|
||||
port._connected_node = self
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
|
||||
@@ -113,7 +113,7 @@ class Switch(Node):
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
if port.connected_link != link:
|
||||
if port._connected_link != link:
|
||||
msg = f"The link does not match the connection at port number {port_number}"
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
@@ -24,19 +24,9 @@ class Simulation(SimComponent):
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
# pass through network actions to the network objects
|
||||
am.add_action(
|
||||
"network",
|
||||
Action(
|
||||
func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator()
|
||||
),
|
||||
)
|
||||
am.add_action("network", Action(func=self.network._action_manager))
|
||||
# pass through domain actions to the domain object
|
||||
am.add_action(
|
||||
"domain",
|
||||
Action(
|
||||
func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator()
|
||||
),
|
||||
)
|
||||
am.add_action("domain", Action(func=self.domain._action_manager))
|
||||
return am
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -204,7 +204,7 @@ class IOSoftware(Software):
|
||||
"max_sessions": self.max_sessions,
|
||||
"tcp": self.tcp,
|
||||
"udp": self.udp,
|
||||
"ports": [port.name for port in self.ports], # TODO: not sure if this should be port.name or port.value
|
||||
"port": self.port.value,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
Reference in New Issue
Block a user