From a4caa3dfe4abbff8e9b88c8114f93d384d0927af Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 3 Apr 2024 15:58:01 +0100 Subject: [PATCH] #2449 fix observation integration --- .../game/agent/observations/file_system_observations.py | 2 +- .../game/agent/observations/firewall_observation.py | 2 +- src/primaite/game/agent/observations/host_observations.py | 2 +- src/primaite/game/agent/observations/node_observations.py | 2 +- src/primaite/game/agent/observations/router_observation.py | 2 +- .../notebooks/Data-Manipulation-E2E-Demonstration.ipynb | 6 +++--- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 9b9434af..3e262055 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -205,7 +205,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): :return: Constructed folder observation instance. :rtype: FolderObservation """ - where = parent_where + ["folders", config.folder_name] + where = parent_where + ["file_system", "folders", config.folder_name] # pass down shared/common config items for file_config in config.files: diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 0c10a8d2..0a1498b1 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -215,7 +215,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :rtype: FirewallObservation """ return cls( - where=parent_where + ["nodes", config.hostname], + where=parent_where + [config.hostname], ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 8ea40be7..6dbde789 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -216,7 +216,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): if parent_where == []: where = ["network", "nodes", config.hostname] else: - where = parent_where + ["nodes", config.hostname] + where = parent_where + [config.hostname] # Pass down shared/common config items for folder_config in config.folders: diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index dce33a04..f11ffebf 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -164,7 +164,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): :return: Constructed nodes observation instance. :rtype: NodesObservation """ - if parent_where is None: + if not parent_where: where = ["network", "nodes"] else: where = parent_where + ["nodes"] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index a7879f09..aeac2766 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -124,7 +124,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :return: Constructed router observation instance. :rtype: RouterObservation """ - where = parent_where + ["nodes", config.hostname] + where = parent_where + [config.hostname] if config.acl is None: config.acl = ACLObservation.ConfigSchema() diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 60d40f9c..a958aa0a 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -592,7 +592,7 @@ "metadata": {}, "outputs": [], "source": [ - "obs['ACL']" + "obs['NODES']['ROUTER0']" ] }, { @@ -616,12 +616,12 @@ " tries += 1\n", " obs, reward, terminated, truncated, info = env.step(0)\n", "\n", - " if obs['NODES'][6]['NICS'][1]['NMNE']['outbound'] == 1:\n", + " if obs['NODES']['HOST5']['NICS'][1]['NMNE']['outbound'] == 1:\n", " # client 1 has NMNEs, let's block it\n", " obs, reward, terminated, truncated, info = env.step(50) # block client 1\n", " print(\"blocking client 1\")\n", " break\n", - " elif obs['NODES'][7]['NICS'][1]['NMNE']['outbound'] == 1:\n", + " elif obs['NODES']['HOST6']['NICS'][1]['NMNE']['outbound'] == 1:\n", " # client 2 has NMNEs, so let's block it\n", " obs, reward, terminated, truncated, info = env.step(51) # block client 2\n", " print(\"blocking client 2\")\n",