diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 84bd3f39..4d28328e 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction): "add_rule", permission_str, protocol, - src_ip, + str(src_ip), src_port, - dst_ip, + str(dst_ip), dst_port, position, ] diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 394a8154..c94cb3ad 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -165,9 +165,13 @@ class PrimaiteGame: for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation action_choice, options = agent.get_action(obs, timestep=self.step_counter) - agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) - self.simulation.apply_request(request) + response = self.simulation.apply_request(request) + agent_actions[agent.agent_name] = { + "action": action_choice, + "parameters": options, + "response": response.model_dump(), + } return agent_actions def advance_timestep(self) -> None: diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 3e21ed16..ed2b4d62 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -93,7 +93,7 @@ class PrimaiteIO: { "episode": episode, "timestep": timestep, - "agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()}, + "agent_actions": agent_actions, } ] ) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 64f33f6a..02481661 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -43,7 +43,7 @@ class RequestType(BaseModel): the request can be performed or not. """ - func: Callable[[List[str], Dict], RequestResponse] + func: Callable[[List[Union[str, int, float]], Dict], RequestResponse] """ ``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function that invokes a class method of your SimComponent. For example if the component is a node and the request type is for @@ -73,7 +73,7 @@ class RequestManager(BaseModel): """maps request name to an RequestType object.""" @validate_call - def __call__(self, request: List[str], context: Dict) -> RequestResponse: + def __call__(self, request: List[Union[str, int, float]], context: Dict) -> RequestResponse: """ Process an request request. @@ -206,7 +206,8 @@ class SimComponent(BaseModel): } return state - def apply_request(self, request: List[str], context: Dict = {}) -> None: + @validate_call + def apply_request(self, request: List[Union[str, int, float]], context: Dict = {}) -> RequestResponse: """ Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index d9b02e8e..4dc222fb 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -100,15 +100,16 @@ class File(FileSystemItemABC): state["file_type"] = self.file_type.name return state - def scan(self) -> None: + def scan(self) -> bool: """Updates the visible statuses of the file.""" if self.deleted: self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}") - return + return False path = self.folder.name + "/" + self.name self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}") self.visible_health_status = self.health_status + return True def reveal_to_red(self) -> None: """Reveals the folder/file to the red agent.""" @@ -117,7 +118,7 @@ class File(FileSystemItemABC): return self.revealed_to_red = True - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Check if the file has been changed. @@ -127,7 +128,7 @@ class File(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}") - return + return False current_hash = None # if file is real, read the file contents @@ -149,12 +150,13 @@ class File(FileSystemItemABC): # if the previous hash and current hash do not match, mark file as corrupted if self.previous_hash is not current_hash: self.corrupt() + return True - def repair(self) -> None: + def repair(self) -> bool: """Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD.""" if self.deleted: self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}") - return + return False # set file status to good if corrupt if self.health_status == FileSystemItemHealthStatus.CORRUPT: @@ -162,12 +164,13 @@ class File(FileSystemItemABC): path = self.folder.name + "/" + self.name self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}") + return True - def corrupt(self) -> None: + def corrupt(self) -> bool: """Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT.""" if self.deleted: self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}") - return + return False # set file status to good if corrupt if self.health_status == FileSystemItemHealthStatus.GOOD: @@ -175,24 +178,27 @@ class File(FileSystemItemABC): path = self.folder.name + "/" + self.name self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}") + return True - def restore(self) -> None: + def restore(self) -> bool: """Determines if the file needs to be repaired or unmarked as deleted.""" if self.deleted: self.deleted = False - return + return True if self.health_status == FileSystemItemHealthStatus.CORRUPT: self.health_status = FileSystemItemHealthStatus.GOOD path = self.folder.name + "/" + self.name self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}") + return True - def delete(self): + def delete(self) -> bool: """Marks the file as deleted.""" if self.deleted: self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}") - return + return False self.deleted = True self.sys_log.info(f"File deleted {self.folder_name}/{self.name}") + return True diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 9ef1ae59..fff08b23 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -307,7 +307,7 @@ class Folder(FileSystemItemABC): """Returns true if the folder is being quarantined.""" pass - def scan(self, instant_scan: bool = False) -> None: + def scan(self, instant_scan: bool = False) -> bool: """ Update Folder visible status. @@ -315,7 +315,7 @@ class Folder(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to scan deleted folder {self.name}") - return + return False if instant_scan: for file_id in self.files: @@ -323,7 +323,7 @@ class Folder(FileSystemItemABC): file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: self.visible_health_status = FileSystemItemHealthStatus.CORRUPT - return + return True if self.scan_countdown <= 0: # scan one file per timestep @@ -332,6 +332,7 @@ class Folder(FileSystemItemABC): else: # scan already in progress self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})") + return True def reveal_to_red(self, instant_scan: bool = False): """ @@ -358,7 +359,7 @@ class Folder(FileSystemItemABC): # scan already in progress self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})") - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Runs a :func:`check_hash` on all files in the folder. @@ -371,7 +372,7 @@ class Folder(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to check hash of deleted folder {self.name}") - return + return False # iterate through the files and run a check hash no_corrupted_files = True @@ -387,12 +388,13 @@ class Folder(FileSystemItemABC): self.corrupt() self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})") + return True - def repair(self) -> None: + def repair(self) -> bool: """Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD.""" if self.deleted: self.sys_log.error(f"Unable to repair deleted folder {self.name}") - return + return False # iterate through the files in the folder for file_id in self.files: @@ -406,8 +408,9 @@ class Folder(FileSystemItemABC): self.health_status = FileSystemItemHealthStatus.GOOD self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})") + return True - def restore(self) -> None: + def restore(self) -> bool: """ If a Folder is corrupted, run a repair on the folder and its child files. @@ -423,12 +426,13 @@ class Folder(FileSystemItemABC): else: # scan already in progress self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})") + return True - def corrupt(self) -> None: + def corrupt(self) -> bool: """Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT.""" if self.deleted: self.sys_log.error(f"Unable to corrupt deleted folder {self.name}") - return + return False # iterate through the files in the folder for file_id in self.files: @@ -439,11 +443,13 @@ class Folder(FileSystemItemABC): self.health_status = FileSystemItemHealthStatus.CORRUPT self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})") + return True - def delete(self): + def delete(self) -> bool: """Marks the file as deleted. Prevents agent actions from occuring.""" if self.deleted: self.sys_log.error(f"Unable to delete an already deleted folder {self.name}") - return + return False self.deleted = True + return True diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 3349bed4..d5945653 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -287,24 +287,24 @@ class WiredNetworkInterface(NetworkInterface, ABC): _connected_link: Optional[Link] = None "The network link to which the network interface is connected." - def enable(self): + def enable(self) -> bool: """Attempt to enable the network interface.""" if self.enabled: - return + return True if not self._connected_node: _LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node") - return + return False if self._connected_node.operating_state != NodeOperatingState.ON: self._connected_node.sys_log.info( f"Interface {self} cannot be enabled as the connected Node is not powered on" ) - return + return False if not self._connected_link: self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.") - return + return False self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") @@ -313,11 +313,12 @@ class WiredNetworkInterface(NetworkInterface, ABC): ) if self._connected_link: self._connected_link.endpoint_up() + return True - def disable(self): + def disable(self) -> bool: """Disable the network interface.""" if not self.enabled: - return + return True self.enabled = False if self._connected_node: self._connected_node.sys_log.info(f"Network Interface {self} disabled") @@ -325,6 +326,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): _LOGGER.debug(f"Interface {self} disabled") if self._connected_link: self._connected_link.endpoint_down() + return True def connect_link(self, link: Link): """ @@ -499,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): return state - def enable(self): + def enable(self) -> bool: """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. @@ -515,8 +517,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): try: pass self._connected_node.default_gateway_hello() + return True except AttributeError: pass + return False @abstractmethod def receive_frame(self, frame: Frame) -> bool: diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py index 721814f8..4b73b6a8 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py @@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): return state - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return True - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return True def send_frame(self, frame: Frame) -> bool: """ diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py index 7b8f6f54..2e0a1823 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py @@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface): return state - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return True - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return True def send_frame(self, frame: Frame) -> bool: """