diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 2375a391..624c9ca4 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -14,31 +14,36 @@ parameters: - name: matrix type: object default: - # - job_name: 'UbuntuPython38' - # py: '3.8' - # img: 'ubuntu-latest' - # every_time: false - # publish_coverage: false - - job_name: 'UbuntuPython311' - py: '3.11' + - job_name: 'UbuntuPython39' + py: '3.9' + img: 'ubuntu-latest' + every_time: false + publish_coverage: false + - job_name: 'UbuntuPython310' + py: '3.10' img: 'ubuntu-latest' every_time: true publish_coverage: true - # - job_name: 'WindowsPython38' - # py: '3.8' - # img: 'windows-latest' - # every_time: false - # publish_coverage: false + - job_name: 'UbuntuPython311' + py: '3.11' + img: 'ubuntu-latest' + every_time: false + publish_coverage: false + - job_name: 'WindowsPython39' + py: '3.9' + img: 'windows-latest' + every_time: false + publish_coverage: false - job_name: 'WindowsPython311' py: '3.11' img: 'windows-latest' every_time: false publish_coverage: false - # - job_name: 'MacOSPython38' - # py: '3.8' - # img: 'macOS-latest' - # every_time: false - # publish_coverage: false + - job_name: 'MacOSPython39' + py: '3.9' + img: 'macOS-latest' + every_time: false + publish_coverage: false - job_name: 'MacOSPython311' py: '3.11' img: 'macOS-latest' @@ -63,7 +68,7 @@ stages: displayName: 'Use Python ${{ item.py }}' - script: | - python -m pip install pre-commit + python -m pip install pre-commit>=6.1 pre-commit install pre-commit run --all-files displayName: 'Run pre-commits' @@ -71,7 +76,6 @@ stages: - script: | python -m pip install --upgrade pip==23.0.1 pip install wheel==0.38.4 --upgrade - pip install setuptools==66 --upgrade pip install build==0.10.0 pip install pytest-azurepipelines displayName: 'Install build dependencies' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df3bb504..d004dd6c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - id: isort args: [ "--profile", "black" ] - repo: http://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index 14a96349..ce2087ca 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -2,44 +2,44 @@ © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| Name | Version | License | Description | URL | -+===================+=========+====================================+=======================================================================================================+====================================================================+ -| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| ipywidgets | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| polars | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| ray | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| Deepdiff | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| sb3_contrib | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| Name | Supported Version | Built Version | License | Description | URL | ++===================+=====================+===============+======================================+========================================================================================================+=====================================================================+ +| gymnasium | 0.28.1 | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| ipywidgets | ~=8.0 | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| jupyterlab | 3.6.1 | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| kaleido | ==0.2.1 | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| matplotlib | >=3.7.1 | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| networkx | 3.1 | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| numpy | ~1.23 | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| platformdirs | 3.5.1 | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| plotly | 5.15 | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| polars | 0.20.30 | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| prettytable | 3.8.0 | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| pydantic | 2.7.0 | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| PyYAML | >=6.0 | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| ray | >=2.20, <2.33 | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| stable-baselines3 | 2.1.0 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| tensorflow | ~=2.12 | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| typer | >=0.9 | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| Deepdiff | 8.0.1 | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| sb3_contrib | 2.1.0 | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking) | https://github.com/Stable-Baselines-Team/stable-baselines3-contrib | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ diff --git a/pyproject.toml b/pyproject.toml index 354df8b2..e840797c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "primaite" description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}] license = {file = "LICENSE"} -requires-python = ">=3.9, <3.12" +requires-python = ">=3.9, <3.13" dynamic = ["version", "readme"] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -26,15 +26,15 @@ dependencies = [ "gymnasium==0.28.1", "jupyterlab==3.6.1", "kaleido==0.2.1", - "matplotlib==3.7.1", + "matplotlib>=3.7.1", "networkx==3.1", - "numpy==1.23.5", + "numpy~=1.23", "platformdirs==3.5.1", "plotly==5.15.0", "polars==0.20.30", "prettytable==3.8.0", - "PyYAML==6.0", - "typer[all]==0.9.0", + "PyYAML>=6.0", + "typer[all]>=0.9", "pydantic==2.7.0", "ipywidgets", "deepdiff" @@ -53,8 +53,8 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ "ray[rllib] >= 2.20.0, <2.33", - "tensorflow==2.12.0", - "stable-baselines3[extra]==2.1.0", + "tensorflow~=2.12", + "stable-baselines3==2.1.0", "sb3-contrib==2.1.0", ] dev = [ @@ -69,7 +69,7 @@ dev = [ "pytest-xdist==3.3.1", "pytest-cov==4.0.0", "pytest-flake8==1.1.1", - "setuptools==66", + "setuptools==75.6.0", "Sphinx==7.1.2", "sphinx-copybutton==0.5.2", "wheel==0.38.4", diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index fa10a463..b0d5d087 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -161,8 +161,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index b0131c8c..e45f193e 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -153,8 +153,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP @@ -668,8 +668,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index fac92a94..5d9dc848 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -93,7 +93,7 @@ class AgentLog: def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False): if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal: - print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}") + print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}") def debug(self, msg: str, to_terminal: bool = False): """ diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 86a6463a..cb2cb38e 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -24,8 +24,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"): """List of IP addresses.""" wildcard_list: Optional[List[str]] = None """List of wildcard strings.""" - port_list: Optional[List[int]] = None - """List of port numbers.""" + port_list: Optional[List[str]] = None + """List of port names.""" protocol_list: Optional[List[str]] = None """List of protocol names.""" num_rules: Optional[int] = None @@ -37,7 +37,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], - port_list: List[int], + port_list: List[str], protocol_list: List[str], ) -> None: """ @@ -51,8 +51,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"): :type ip_list: List[IPv4Address] :param wildcard_list: List of wildcard strings. :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] + :param port_list: List of port names. + :type port_list: List[str] :param protocol_list: List of protocol names. :type protocol_list: List[str] """ @@ -60,7 +60,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.num_rules: int = num_rules self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} + self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 50ca93fd..784eaa7f 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -190,6 +190,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): if self.files: self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)} + self.cached_obs: Optional[ObsType] = self.default_observation + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -204,7 +206,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): return self.default_observation if self.file_system_requires_scan: - health_status = folder_state["visible_status"] + if not folder_state["scanned_this_step"]: + health_status = self.cached_obs["health_status"] + else: + health_status = folder_state["visible_status"] else: health_status = folder_state["health_status"] diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index a194bb53..44541f24 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -27,13 +27,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[str]] = None """List of ports for encoding ACLs.""" protocol_list: Optional[List[str]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - include_users: Optional[bool] = True + include_users: Optional[bool] = None """If True, report user session information.""" def __init__( @@ -41,7 +41,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): where: WhereType, ip_list: List[str], wildcard_list: List[str], - port_list: List[int], + port_list: List[str], protocol_list: List[str], num_rules: int, include_users: bool, @@ -56,8 +56,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type ip_list: List[str] :param wildcard_list: List of wildcard rules. :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] + :param port_list: List of port names. + :type port_list: List[str] :param protocol_list: List of protocol types. :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. @@ -139,6 +139,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -152,29 +154,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): firewall_state = access_from_nested_dict(state, self.where) if firewall_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = { - "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "ACL": { - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), + + is_on = firewall_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = { + "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, + "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), - }, - }, - } - if self.include_users: - sess = firewall_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -185,34 +193,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Gymnasium space representing the observation space for firewall status. :rtype: spaces.Space """ - space = spaces.Dict( - { - "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "ACL": spaces.Dict( - { - "INTERNAL": spaces.Dict( - { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, - } - ), - } - ), - } - ) - return space + shape = { + "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), + "ACL": spaces.Dict( + { + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), + } + ), + } + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) + return spaces.Dict(shape) @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 03e9aca1..e46cc805 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -54,7 +54,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ - include_users: Optional[bool] = True + include_users: Optional[bool] = None """If True, report user session information.""" def __init__( @@ -191,25 +191,31 @@ class HostObservation(AbstractObservation, identifier="HOST"): if node_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {} + is_on = node_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = {} + if self.services: + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + if self.applications: + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} + if self.folders: + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + if self.nics: + obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} + if self.include_num_access: + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } + obs["operating_status"] = node_state["operating_state"] - if self.services: - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - if self.applications: - obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} - if self.folders: - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - if self.nics: - obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} - if self.include_num_access: - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] - if self.include_users: - sess = node_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), - } return obs @property diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 03869367..0c5d11da 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -56,7 +56,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[str]] = None """List of ports for encoding ACLs.""" protocol_list: Optional[List[str]] = None """List of protocols for encoding ACLs.""" diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index ca455f4c..9687d083 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -33,13 +33,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[str]] = None """List of ports for encoding ACLs.""" protocol_list: Optional[List[str]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - include_users: Optional[bool] = True + include_users: Optional[bool] = None """If True, report user session information.""" def __init__( @@ -84,6 +84,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): } if self.ports: self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)} + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -98,16 +100,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): if router_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {} - obs["ACL"] = self.acl.observe(state) - if self.ports: - obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} - if self.include_users: - sess = router_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), - } + is_on = router_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = {} + obs["ACL"] = self.acl.observe(state) + if self.ports: + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -121,6 +128,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): shape = {"ACL": self.acl.space} if self.ports: shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8bc37597..f59117f4 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -258,6 +258,7 @@ class PrimaiteGame: net = sim.network simulation_config = cfg.get("simulation", {}) + defaults_config = cfg.get("defaults", {}) network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) @@ -338,6 +339,18 @@ class PrimaiteGame: _LOGGER.error(msg) raise ValueError(msg) + # TODO: handle simulation defaults more cleanly + if "node_start_up_duration" in defaults_config: + new_node.start_up_duration = defaults_config["node_startup_duration"] + if "node_shut_down_duration" in defaults_config: + new_node.shut_down_duration = defaults_config["node_shut_down_duration"] + if "node_scan_duration" in defaults_config: + new_node.node_scan_duration = defaults_config["node_scan_duration"] + if "folder_scan_duration" in defaults_config: + new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"] + if "folder_restore_duration" in defaults_config: + new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"] + if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa for user_cfg in node_cfg["users"]: @@ -384,6 +397,15 @@ class PrimaiteGame: msg = f"Configuration contains an invalid service type: {service_type}" _LOGGER.error(msg) raise ValueError(msg) + + # TODO: handle simulation defaults more cleanly + if "service_fix_duration" in defaults_config: + new_service.fixing_duration = defaults_config["service_fix_duration"] + if "service_restart_duration" in defaults_config: + new_service.restart_duration = defaults_config["service_restart_duration"] + if "service_install_duration" in defaults_config: + new_service.install_duration = defaults_config["service_install_duration"] + # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 858b4bb6..7fde0a49 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -11,6 +11,15 @@ "PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index d1154b54..756fc44f 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -15,6 +15,15 @@ "*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 143bbe09..dbc6f0c1 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -371,6 +371,15 @@ "First, load the required modules" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index a832f3cc..f8691d7d 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -9,6 +9,15 @@ "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Privilege-Escalation-and Data-Loss-Example.ipynb b/src/primaite/notebooks/Privilege-Escalation-and-Data-Loss-Example.ipynb similarity index 100% rename from src/primaite/notebooks/Privilege-Escalation-and Data-Loss-Example.ipynb rename to src/primaite/notebooks/Privilege-Escalation-and-Data-Loss-Example.ipynb diff --git a/src/primaite/notebooks/Requests-and-Responses.ipynb b/src/primaite/notebooks/Requests-and-Responses.ipynb index da614c93..83aed07c 100644 --- a/src/primaite/notebooks/Requests-and-Responses.ipynb +++ b/src/primaite/notebooks/Requests-and-Responses.ipynb @@ -25,6 +25,15 @@ "Let's set up a minimal network simulation and send some requests to see how it works." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index fdf405a7..9aa4e96a 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -18,6 +18,15 @@ "The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). " ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 19e95a95..76cab86a 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -18,6 +18,15 @@ "#### First, Import packages and read our config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,8 +41,6 @@ "from ray.rllib.algorithms.ppo import PPOConfig\n", "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", "\n", - "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", - "# to copy the files to your user data path.\n", "with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 0fd212f2..7252b046 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -11,6 +11,15 @@ "This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 5255b0ad..2b554475 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -18,6 +18,15 @@ "#### First, we import the inital packages and read in our configuration file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/_package_data/uc2_attack.png b/src/primaite/notebooks/_package_data/uc2_attack.png index 8b8df5ce..03797d00 100644 Binary files a/src/primaite/notebooks/_package_data/uc2_attack.png and b/src/primaite/notebooks/_package_data/uc2_attack.png differ diff --git a/src/primaite/notebooks/_package_data/uc2_network.png b/src/primaite/notebooks/_package_data/uc2_network.png index 20fa43c9..10989201 100644 Binary files a/src/primaite/notebooks/_package_data/uc2_network.png and b/src/primaite/notebooks/_package_data/uc2_network.png differ diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 305cfd70..ad386f34 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -18,6 +18,15 @@ "Import packages and read config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 8ff4b6fb..54e649f2 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -30,6 +30,11 @@ class FileSystem(SimComponent): num_file_deletions: int = 0 "Number of file deletions in the current step." + _default_folder_scan_duration: Optional[int] = None + "Override default scan duration for folders" + _default_folder_restore_duration: Optional[int] = None + "Override default restore duration for folders" + def __init__(self, **kwargs): super().__init__(**kwargs) # Ensure a default root folder @@ -258,6 +263,11 @@ class FileSystem(SimComponent): name=folder.name, request_type=RequestType(func=folder._request_manager) ) self.folders[folder.uuid] = folder + # set the folder scan and restore durations. + if self._default_folder_scan_duration is not None: + folder.scan_duration = self._default_folder_scan_duration + if self._default_folder_restore_duration is not None: + folder.restore_duration = self._default_folder_restore_duration return folder def delete_folder(self, folder_name: str) -> bool: diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index 48b95d20..db51924c 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str: class FileSystemItemHealthStatus(Enum): """Status of the FileSystemItem.""" + NONE = 0 + """File system item health status is not known.""" + GOOD = 1 """File/Folder is OK.""" @@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent): health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD "Actual status of the current FileSystemItem" - visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD + visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE "Visible status of the current FileSystemItem" previous_hash: Optional[str] = None diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 78dba4e6..3dd5d1ce 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -46,7 +46,7 @@ class Folder(FileSystemItemABC): :param sys_log: The SysLog instance to us to create system logs. """ super().__init__(**kwargs) - + self._scanned_this_step: bool = False self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") def _init_request_manager(self) -> RequestManager: @@ -83,6 +83,7 @@ class Folder(FileSystemItemABC): state = super().describe_state() state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()} state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()} + state["scanned_this_step"] = self._scanned_this_step return state def show(self, markdown: bool = False): @@ -135,7 +136,7 @@ class Folder(FileSystemItemABC): def pre_timestep(self, timestep: int) -> None: """Apply pre-timestep logic.""" super().pre_timestep(timestep) - + self._scanned_this_step = False for file in self.files.values(): file.pre_timestep(timestep) @@ -148,9 +149,17 @@ class Folder(FileSystemItemABC): for file_id in self.files: file = self.get_file_by_id(file_uuid=file_id) file.scan() - if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.health_status = FileSystemItemHealthStatus.CORRUPT + # set folder health to worst file's health by generating a list of file healths. If no files, use 0 + self.health_status = FileSystemItemHealthStatus( + max( + [f.health_status.value for f in self.files.values()] + or [ + 0, + ] + ) + ) self.visible_health_status = self.health_status + self._scanned_this_step = True def _reveal_to_red_timestep(self) -> None: """Apply reveal to red timestep.""" diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 16cefdd6..82875b97 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -118,6 +118,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: + self._active = True """ Connects the client to a given FTP server. @@ -174,6 +175,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): :param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ + self._active = True # send a disconnect request payload to FTP server payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT) software_manager: SoftwareManager = self.software_manager @@ -219,6 +221,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): :param: session_id: The id of the session :type: session_id: Optional[str] """ + self._active = True # check if the file to transfer exists on the client file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name) if not file_to_transfer: @@ -276,6 +279,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. :type: dest_port: Optional[int] """ + self._active = True # check if FTP is currently connected to IP self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) @@ -327,6 +331,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): This helps prevent an FTP request loop - FTP client and servers can exist on the same node. """ + self._active = True if not self._can_perform_action(): return False diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 52f451e1..13acda70 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -3,9 +3,11 @@ from abc import ABC from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import StrictBool + from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.utils.validation.port import Port @@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC): Contains shared methods between both classes. """ + _active: StrictBool = False + """Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state.""" + + def pre_timestep(self, timestep: int) -> None: + """When a new timestep begins, clear the _active attribute.""" + self._active = False + return super().pre_timestep(timestep) + def describe_state(self) -> Dict: """Returns a Dict of the FTPService state.""" - return super().describe_state() + state = super().describe_state() + + # override so that the service is shows as running only if actively transmitting data this timestep + if self.operating_state == ServiceOperatingState.RUNNING and not self._active: + state["operating_state"] = ServiceOperatingState.STOPPED.value + return state def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ @@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC): :param: session_id: session ID linked to the FTP Packet. Optional. :type: session_id: Optional[str] """ + self._active = True if payload.ftp_command is not None: self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.") @@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC): :param: payload: The FTP Packet that contains the file data :type: FTPPacket """ + self._active = True try: file_name = payload.ftp_command_args["dest_file_name"] folder_name = payload.ftp_command_args["dest_folder_name"] @@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC): :param: is_response: is true if the data being sent is in response to a request. Default False. :type: is_response: bool """ + self._active = True # send STOR request payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.STOR, @@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC): :param: payload: The FTP Packet that contains the file data :type: FTPPacket """ + self._active = True try: # find the file file_name = payload.ftp_command_args["src_file_name"] @@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC): :return: True if successful, False otherwise. """ + self._active = True self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}") return super().send( diff --git a/tests/assets/configs/action_penalty.yaml b/tests/assets/configs/action_penalty.yaml index 9ab13036..3e57f579 100644 --- a/tests/assets/configs/action_penalty.yaml +++ b/tests/assets/configs/action_penalty.yaml @@ -69,8 +69,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 1cd0883c..9cf95a64 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -74,8 +74,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 10a92d7a..a39bf876 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -88,8 +88,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml index 328fe413..726c9ab0 100644 --- a/tests/assets/configs/data_manipulation.yaml +++ b/tests/assets/configs/data_manipulation.yaml @@ -160,8 +160,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index e277a881..41b7fce9 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -102,8 +102,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/extended_config.yaml b/tests/assets/configs/extended_config.yaml index 0ec0c91f..bff58ebd 100644 --- a/tests/assets/configs/extended_config.yaml +++ b/tests/assets/configs/extended_config.yaml @@ -161,8 +161,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index 6b454a12..4b11dbcc 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -77,8 +77,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/fixing_duration_one_item.yaml b/tests/assets/configs/fixing_duration_one_item.yaml index 02aa8e4b..da5a9993 100644 --- a/tests/assets/configs/fixing_duration_one_item.yaml +++ b/tests/assets/configs/fixing_duration_one_item.yaml @@ -81,8 +81,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 3b746273..93baf4af 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -152,8 +152,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP @@ -666,8 +666,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/shared_rewards.yaml b/tests/assets/configs/shared_rewards.yaml index 7ad5371d..d5615a72 100644 --- a/tests/assets/configs/shared_rewards.yaml +++ b/tests/assets/configs/shared_rewards.yaml @@ -151,8 +151,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/software_fixing_duration.yaml b/tests/assets/configs/software_fixing_duration.yaml index 073a5f83..f685b420 100644 --- a/tests/assets/configs/software_fixing_duration.yaml +++ b/tests/assets/configs/software_fixing_duration.yaml @@ -81,8 +81,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index cafcc72b..25bc38e6 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -155,8 +155,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index cd5d08d3..2d124981 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -103,8 +103,8 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP diff --git a/tests/integration_tests/game_layer/actions/test_file_request_permission.py b/tests/integration_tests/game_layer/actions/test_file_request_permission.py index 0976abdc..cab80434 100644 --- a/tests/integration_tests/game_layer/actions/test_file_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_file_request_permission.py @@ -69,7 +69,7 @@ def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE action = ( "node_file_scan", diff --git a/tests/integration_tests/game_layer/actions/test_folder_request_permission.py b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py index 9cd4bfcf..207f7d48 100644 --- a/tests/integration_tests/game_layer/actions/test_folder_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py @@ -52,12 +52,12 @@ def test_folder_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge folder = client_1.file_system.get_folder(folder_name="downloads") assert folder.health_status == FileSystemItemHealthStatus.GOOD - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE action = ( "node_folder_scan", diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 0268cb95..19c0c4bc 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -32,11 +32,11 @@ def test_file_observation(simulation): assert dog_file_obs.space["health_status"] == spaces.Discrete(6) observation_state = dog_file_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # good initial + assert observation_state.get("health_status") == 0 # initially unset file.corrupt() observation_state = dog_file_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # scan file so this changes + assert observation_state.get("health_status") == 0 # still default unset value because no scan happened file.scan() file.apply_timestep(0) # apply time step @@ -63,11 +63,11 @@ def test_folder_observation(simulation): observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("FILES") is not None - assert observation_state.get("health_status") == 1 + assert observation_state.get("health_status") == 0 # initially unset file.corrupt() # corrupt just the file observation_state = root_folder_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # scan folder to change this + assert observation_state.get("health_status") == 0 # still unset as no scan occurred yet folder.scan() for i in range(folder.scan_duration + 1): diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 800549bc..5a308cf8 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -275,7 +275,7 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge client_1 = game.simulation.network.get_node_by_hostname("client_1") file = client_1.file_system.get_file("downloads", "cat.png") assert file.health_status == FileSystemItemHealthStatus.GOOD - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE # 2: perform a scan and make sure nothing has changed action = ( diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 23364f13..5afad296 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -17,12 +17,7 @@ def test_file_observation(): dog_file_obs = FileObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, - file_system_requires_scan=True, + file_system_requires_scan=False, ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) - - -# TODO: -# def test_file_num_access(): -# ... diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 31732f77..bb25f8c8 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -163,7 +163,7 @@ def test_restore_backup_without_updating_scan(uc2_network): db_service.db_file.corrupt() # corrupt the db assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt - assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet + assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet db_service.db_file.scan() # scan the db file @@ -190,7 +190,7 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network): db_service.db_file.corrupt() # corrupt the db assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt - assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet + assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet db_service.db_file.scan() # scan the db file diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 5170bcf3..5156a29f 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -69,8 +69,8 @@ class TestFileSystemRequiresScan: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP @@ -119,14 +119,20 @@ class TestFileSystemRequiresScan: assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3 def test_folder_require_scan(self): - folder_state = {"health_status": 3, "visible_status": 1} + folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": False} obs_requiring_scan = FolderObservation( [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True ) - assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + assert obs_requiring_scan.observe(folder_state)["health_status"] == 0 obs_not_requiring_scan = FolderObservation( [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False ) assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3 + + folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": True} + obs_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True + ) + assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py index 9cacdccf..9691080d 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py @@ -22,12 +22,12 @@ def test_file_scan(file_system): file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") assert file.health_status == FileSystemItemHealthStatus.GOOD - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE file.scan() diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py index 2729e5e4..59f3f000 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py @@ -24,7 +24,7 @@ def test_file_scan_request(populated_file_system): file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE fs.apply_request(request=["folder", folder.name, "file", file.name, "scan"]) @@ -94,7 +94,7 @@ def test_deleted_file_cannot_be_interacted_with(populated_file_system): assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT assert ( fs.get_file(folder_name=folder.name, file_name=file.name).visible_health_status - == FileSystemItemHealthStatus.GOOD + == FileSystemItemHealthStatus.NONE ) fs.apply_request(request=["delete", "file", folder.name, file.name]) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py index 10393c6c..b5d9b269 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py @@ -44,25 +44,25 @@ def test_folder_scan(file_system): file2: File = folder.get_file_by_id(file_uuid=list(folder.files)[0]) assert folder.health_status == FileSystemItemHealthStatus.GOOD - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.scan() folder.apply_timestep(timestep=0) assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.apply_timestep(timestep=1) folder.apply_timestep(timestep=2) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py index 07c1ec46..72857638 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py @@ -29,18 +29,18 @@ def test_folder_scan_request(populated_file_system): folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE fs.apply_request(request=["folder", folder.name, "scan"]) folder.apply_timestep(timestep=0) assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.apply_timestep(timestep=1) folder.apply_timestep(timestep=2) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 605f8c3b..672a4b5f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -70,13 +70,13 @@ def test_node_os_scan(node): # add folder and file to node folder: Folder = node.file_system.create_folder(folder_name="test_folder") folder.corrupt() - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE file: File = node.file_system.create_file(folder_name="test_folder", file_name="file.txt") file2: File = node.file_system.create_file(folder_name="test_folder", file_name="file2.txt") file.corrupt() file2.corrupt() - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE # run os scan node.apply_request(["os", "scan"])