Add call validation

This commit is contained in:
Marek Wolan
2024-03-08 15:57:43 +00:00
parent beb51834f9
commit 0447a05084
9 changed files with 69 additions and 44 deletions

View File

@@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction):
"add_rule",
permission_str,
protocol,
src_ip,
str(src_ip),
src_port,
dst_ip,
str(dst_ip),
dst_port,
position,
]

View File

@@ -165,9 +165,13 @@ class PrimaiteGame:
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
response = self.simulation.apply_request(request)
agent_actions[agent.agent_name] = {
"action": action_choice,
"parameters": options,
"response": response.model_dump(),
}
return agent_actions
def advance_timestep(self) -> None:

View File

@@ -93,7 +93,7 @@ class PrimaiteIO:
{
"episode": episode,
"timestep": timestep,
"agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
"agent_actions": agent_actions,
}
]
)

View File

@@ -43,7 +43,7 @@ class RequestType(BaseModel):
the request can be performed or not.
"""
func: Callable[[List[str], Dict], RequestResponse]
func: Callable[[List[Union[str, int, float]], Dict], RequestResponse]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
@@ -73,7 +73,7 @@ class RequestManager(BaseModel):
"""maps request name to an RequestType object."""
@validate_call
def __call__(self, request: List[str], context: Dict) -> RequestResponse:
def __call__(self, request: List[Union[str, int, float]], context: Dict) -> RequestResponse:
"""
Process an request request.
@@ -206,7 +206,8 @@ class SimComponent(BaseModel):
}
return state
def apply_request(self, request: List[str], context: Dict = {}) -> None:
@validate_call
def apply_request(self, request: List[Union[str, int, float]], context: Dict = {}) -> RequestResponse:
"""
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.

View File

@@ -100,15 +100,16 @@ class File(FileSystemItemABC):
state["file_type"] = self.file_type.name
return state
def scan(self) -> None:
def scan(self) -> bool:
"""Updates the visible statuses of the file."""
if self.deleted:
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
return
return False
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
self.visible_health_status = self.health_status
return True
def reveal_to_red(self) -> None:
"""Reveals the folder/file to the red agent."""
@@ -117,7 +118,7 @@ class File(FileSystemItemABC):
return
self.revealed_to_red = True
def check_hash(self) -> None:
def check_hash(self) -> bool:
"""
Check if the file has been changed.
@@ -127,7 +128,7 @@ class File(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}")
return
return False
current_hash = None
# if file is real, read the file contents
@@ -149,12 +150,13 @@ class File(FileSystemItemABC):
# if the previous hash and current hash do not match, mark file as corrupted
if self.previous_hash is not current_hash:
self.corrupt()
return True
def repair(self) -> None:
def repair(self) -> bool:
"""Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
if self.deleted:
self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}")
return
return False
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
@@ -162,12 +164,13 @@ class File(FileSystemItemABC):
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
return True
def corrupt(self) -> None:
def corrupt(self) -> bool:
"""Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT."""
if self.deleted:
self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}")
return
return False
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.GOOD:
@@ -175,24 +178,27 @@ class File(FileSystemItemABC):
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
return True
def restore(self) -> None:
def restore(self) -> bool:
"""Determines if the file needs to be repaired or unmarked as deleted."""
if self.deleted:
self.deleted = False
return
return True
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
return True
def delete(self):
def delete(self) -> bool:
"""Marks the file as deleted."""
if self.deleted:
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
return
return False
self.deleted = True
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
return True

View File

@@ -307,7 +307,7 @@ class Folder(FileSystemItemABC):
"""Returns true if the folder is being quarantined."""
pass
def scan(self, instant_scan: bool = False) -> None:
def scan(self, instant_scan: bool = False) -> bool:
"""
Update Folder visible status.
@@ -315,7 +315,7 @@ class Folder(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to scan deleted folder {self.name}")
return
return False
if instant_scan:
for file_id in self.files:
@@ -323,7 +323,7 @@ class Folder(FileSystemItemABC):
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
return
return True
if self.scan_countdown <= 0:
# scan one file per timestep
@@ -332,6 +332,7 @@ class Folder(FileSystemItemABC):
else:
# scan already in progress
self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})")
return True
def reveal_to_red(self, instant_scan: bool = False):
"""
@@ -358,7 +359,7 @@ class Folder(FileSystemItemABC):
# scan already in progress
self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})")
def check_hash(self) -> None:
def check_hash(self) -> bool:
"""
Runs a :func:`check_hash` on all files in the folder.
@@ -371,7 +372,7 @@ class Folder(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to check hash of deleted folder {self.name}")
return
return False
# iterate through the files and run a check hash
no_corrupted_files = True
@@ -387,12 +388,13 @@ class Folder(FileSystemItemABC):
self.corrupt()
self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})")
return True
def repair(self) -> None:
def repair(self) -> bool:
"""Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD."""
if self.deleted:
self.sys_log.error(f"Unable to repair deleted folder {self.name}")
return
return False
# iterate through the files in the folder
for file_id in self.files:
@@ -406,8 +408,9 @@ class Folder(FileSystemItemABC):
self.health_status = FileSystemItemHealthStatus.GOOD
self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})")
return True
def restore(self) -> None:
def restore(self) -> bool:
"""
If a Folder is corrupted, run a repair on the folder and its child files.
@@ -423,12 +426,13 @@ class Folder(FileSystemItemABC):
else:
# scan already in progress
self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})")
return True
def corrupt(self) -> None:
def corrupt(self) -> bool:
"""Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT."""
if self.deleted:
self.sys_log.error(f"Unable to corrupt deleted folder {self.name}")
return
return False
# iterate through the files in the folder
for file_id in self.files:
@@ -439,11 +443,13 @@ class Folder(FileSystemItemABC):
self.health_status = FileSystemItemHealthStatus.CORRUPT
self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})")
return True
def delete(self):
def delete(self) -> bool:
"""Marks the file as deleted. Prevents agent actions from occuring."""
if self.deleted:
self.sys_log.error(f"Unable to delete an already deleted folder {self.name}")
return
return False
self.deleted = True
return True

View File

@@ -287,24 +287,24 @@ class WiredNetworkInterface(NetworkInterface, ABC):
_connected_link: Optional[Link] = None
"The network link to which the network interface is connected."
def enable(self):
def enable(self) -> bool:
"""Attempt to enable the network interface."""
if self.enabled:
return
return True
if not self._connected_node:
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
return
return False
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.info(
f"Interface {self} cannot be enabled as the connected Node is not powered on"
)
return
return False
if not self._connected_link:
self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.")
return
return False
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
@@ -313,11 +313,12 @@ class WiredNetworkInterface(NetworkInterface, ABC):
)
if self._connected_link:
self._connected_link.endpoint_up()
return True
def disable(self):
def disable(self) -> bool:
"""Disable the network interface."""
if not self.enabled:
return
return True
self.enabled = False
if self._connected_node:
self._connected_node.sys_log.info(f"Network Interface {self} disabled")
@@ -325,6 +326,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
_LOGGER.debug(f"Interface {self} disabled")
if self._connected_link:
self._connected_link.endpoint_down()
return True
def connect_link(self, link: Link):
"""
@@ -499,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
return state
def enable(self):
def enable(self) -> bool:
"""
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
@@ -515,8 +517,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
try:
pass
self._connected_node.default_gateway_hello()
return True
except AttributeError:
pass
return False
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:

View File

@@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
return state
def enable(self):
def enable(self) -> bool:
"""Enable the interface."""
pass
return True
def disable(self):
def disable(self) -> bool:
"""Disable the interface."""
pass
return True
def send_frame(self, frame: Frame) -> bool:
"""

View File

@@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface):
return state
def enable(self):
def enable(self) -> bool:
"""Enable the interface."""
pass
return True
def disable(self):
def disable(self) -> bool:
"""Disable the interface."""
pass
return True
def send_frame(self, frame: Frame) -> bool:
"""