From ccb36f84004134fb361cf7df94806a3a7c099851 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 8 Oct 2023 17:02:54 +0100 Subject: [PATCH] Change observations to make loading from config better --- example_config.yaml | 41 ++++----- sandbox.ipynb | 112 +++++++++++++++++++++--- src/primaite/game/agent/interface.py | 2 + src/primaite/game/agent/observations.py | 109 +++++++++++++++++++++-- src/primaite/game/session.py | 32 +++++-- 5 files changed, 246 insertions(+), 50 deletions(-) diff --git a/example_config.yaml b/example_config.yaml index 9c75c92e..9f679223 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -17,10 +17,12 @@ game_config: - ref: client_1_green_user team: GREEN type: GreenWebBrowsingAgent - observation_space: null + observation_space: + type: UC2GreenObservation action_space: action_list: - type: DONOTHING + # # - type: NODE_LOGON # - type: NODE_LOGOFF # - type: NODE_APPLICATION_EXECUTE @@ -68,6 +70,7 @@ game_config: action_space: action_list: - type: DONOTHING + #8f:1d:3f:32:1c:d6\n", + "2023-10-08 14:40:34,938: SwitchPort 8f:1d:3f:32:1c:d6 connected to Link be:b1:a2:ce:eb:4c/192.168.1.1<-->8f:1d:3f:32:1c:d6\n", + "2023-10-08 14:40:34,939: Link be:b1:a2:ce:eb:4c/192.168.1.1<-->8f:1d:3f:32:1c:d6 up\n", + "2023-10-08 14:40:34,939: Link be:b1:a2:ce:eb:4c/192.168.1.1<-->8f:1d:3f:32:1c:d6 up\n", + "2023-10-08 14:40:34,940: Added link b8070c26-6ad0-4d7e-aed8-c1bcdcf9b438 to connect be:b1:a2:ce:eb:4c/192.168.1.1 and 8f:1d:3f:32:1c:d6\n", + "2023-10-08 14:40:34,942: NIC dc:48:6c:bd:8b:b1/192.168.1.1 connected to Link dc:48:6c:bd:8b:b1/192.168.1.1<-->b2:14:a5:82:c0:7a\n", + "2023-10-08 14:40:34,943: SwitchPort b2:14:a5:82:c0:7a connected to Link dc:48:6c:bd:8b:b1/192.168.1.1<-->b2:14:a5:82:c0:7a\n", + "2023-10-08 14:40:34,945: Link dc:48:6c:bd:8b:b1/192.168.1.1<-->b2:14:a5:82:c0:7a up\n", + "2023-10-08 14:40:34,946: Link dc:48:6c:bd:8b:b1/192.168.1.1<-->b2:14:a5:82:c0:7a up\n", + "2023-10-08 14:40:34,946: Added link 102f5506-a939-4af7-8ebb-8e173e18283c to connect dc:48:6c:bd:8b:b1/192.168.1.1 and b2:14:a5:82:c0:7a\n", + "2023-10-08 14:40:34,947: SwitchPort 00:9f:54:21:e2:f2 connected to Link 00:9f:54:21:e2:f2<-->68:69:bf:51:6c:c0/192.168.1.10\n", + "2023-10-08 14:40:34,949: Link 00:9f:54:21:e2:f2<-->68:69:bf:51:6c:c0/192.168.1.10 up\n", + "2023-10-08 14:40:34,950: NIC 68:69:bf:51:6c:c0/192.168.1.10 connected to Link 00:9f:54:21:e2:f2<-->68:69:bf:51:6c:c0/192.168.1.10\n", + "2023-10-08 14:40:34,951: Link 00:9f:54:21:e2:f2<-->68:69:bf:51:6c:c0/192.168.1.10 up\n", + "2023-10-08 14:40:34,952: Added link 6136fd05-7a16-4afd-aebd-cdf6e255689b to connect 00:9f:54:21:e2:f2 and 68:69:bf:51:6c:c0/192.168.1.10\n", + "2023-10-08 14:40:34,952: SwitchPort 48:cc:7b:ac:dd:f9 connected to Link 48:cc:7b:ac:dd:f9<-->64:15:7d:f0:cd:ce/192.168.1.12\n", + "2023-10-08 14:40:34,954: Link 48:cc:7b:ac:dd:f9<-->64:15:7d:f0:cd:ce/192.168.1.12 up\n", + "2023-10-08 14:40:34,954: NIC 64:15:7d:f0:cd:ce/192.168.1.12 connected to Link 48:cc:7b:ac:dd:f9<-->64:15:7d:f0:cd:ce/192.168.1.12\n", + "2023-10-08 14:40:34,955: Link 48:cc:7b:ac:dd:f9<-->64:15:7d:f0:cd:ce/192.168.1.12 up\n", + "2023-10-08 14:40:34,956: Added link 02c6f4e4-3674-4189-a5a1-334fa86921f6 to connect 48:cc:7b:ac:dd:f9 and 64:15:7d:f0:cd:ce/192.168.1.12\n", + "2023-10-08 14:40:34,957: SwitchPort e4:e3:bb:bf:9e:04 connected to Link e4:e3:bb:bf:9e:04<-->81:cd:6e:b8:3d:6c/192.168.1.14\n", + "2023-10-08 14:40:34,958: Link e4:e3:bb:bf:9e:04<-->81:cd:6e:b8:3d:6c/192.168.1.14 up\n", + "2023-10-08 14:40:34,959: NIC 81:cd:6e:b8:3d:6c/192.168.1.14 connected to Link e4:e3:bb:bf:9e:04<-->81:cd:6e:b8:3d:6c/192.168.1.14\n", + "2023-10-08 14:40:34,960: Link e4:e3:bb:bf:9e:04<-->81:cd:6e:b8:3d:6c/192.168.1.14 up\n", + "2023-10-08 14:40:34,961: Added link 57e0f89d-265b-4d27-838b-828ae9800688 to connect e4:e3:bb:bf:9e:04 and 81:cd:6e:b8:3d:6c/192.168.1.14\n", + "2023-10-08 14:40:34,962: SwitchPort 71:5f:fc:32:79:9f connected to Link 71:5f:fc:32:79:9f<-->29:fa:41:0b:f5:1b/192.168.1.16\n", + "2023-10-08 14:40:34,964: Link 71:5f:fc:32:79:9f<-->29:fa:41:0b:f5:1b/192.168.1.16 up\n", + "2023-10-08 14:40:34,965: NIC 29:fa:41:0b:f5:1b/192.168.1.16 connected to Link 71:5f:fc:32:79:9f<-->29:fa:41:0b:f5:1b/192.168.1.16\n", + "2023-10-08 14:40:34,966: Link 71:5f:fc:32:79:9f<-->29:fa:41:0b:f5:1b/192.168.1.16 up\n", + "2023-10-08 14:40:34,967: Added link 1f382171-5e0d-4a76-9500-27dc68c3c7ee to connect 71:5f:fc:32:79:9f and 29:fa:41:0b:f5:1b/192.168.1.16\n", + "2023-10-08 14:40:34,968: SwitchPort 66:5d:d0:ba:c1:91 connected to Link 66:5d:d0:ba:c1:91<-->0d:22:07:53:7a:e1/192.168.1.110\n", + "2023-10-08 14:40:34,969: Link 66:5d:d0:ba:c1:91<-->0d:22:07:53:7a:e1/192.168.1.110 up\n", + "2023-10-08 14:40:34,970: NIC 0d:22:07:53:7a:e1/192.168.1.110 connected to Link 66:5d:d0:ba:c1:91<-->0d:22:07:53:7a:e1/192.168.1.110\n", + "2023-10-08 14:40:34,971: Link 66:5d:d0:ba:c1:91<-->0d:22:07:53:7a:e1/192.168.1.110 up\n", + "2023-10-08 14:40:34,972: Added link d8ea175e-50c8-4597-99bf-ac9001b30c77 to connect 66:5d:d0:ba:c1:91 and 0d:22:07:53:7a:e1/192.168.1.110\n", + "2023-10-08 14:40:34,972: SwitchPort 22:f5:91:5a:bb:b1 connected to Link 22:f5:91:5a:bb:b1<-->82:e5:30:d9:0e:85/192.168.10.21\n", + "2023-10-08 14:40:34,974: Link 22:f5:91:5a:bb:b1<-->82:e5:30:d9:0e:85/192.168.10.21 up\n", + "2023-10-08 14:40:34,975: NIC 82:e5:30:d9:0e:85/192.168.10.21 connected to Link 22:f5:91:5a:bb:b1<-->82:e5:30:d9:0e:85/192.168.10.21\n", + "2023-10-08 14:40:34,976: Link 22:f5:91:5a:bb:b1<-->82:e5:30:d9:0e:85/192.168.10.21 up\n", + "2023-10-08 14:40:34,977: Added link 40ba49b9-e334-45ce-93da-a1459b80e9a2 to connect 22:f5:91:5a:bb:b1 and 82:e5:30:d9:0e:85/192.168.10.21\n", + "2023-10-08 14:40:34,978: SwitchPort 70:77:d0:12:cd:a0 connected to Link 70:77:d0:12:cd:a0<-->ef:20:20:d8:9a:11/192.168.10.22\n", + "2023-10-08 14:40:34,980: Link 70:77:d0:12:cd:a0<-->ef:20:20:d8:9a:11/192.168.10.22 up\n", + "2023-10-08 14:40:34,981: NIC ef:20:20:d8:9a:11/192.168.10.22 connected to Link 70:77:d0:12:cd:a0<-->ef:20:20:d8:9a:11/192.168.10.22\n", + "2023-10-08 14:40:34,982: Link 70:77:d0:12:cd:a0<-->ef:20:20:d8:9a:11/192.168.10.22 up\n", + "2023-10-08 14:40:34,982: Added link c36027fe-052f-4eb6-b6c6-10bf817c7ac9 to connect 70:77:d0:12:cd:a0 and ef:20:20:d8:9a:11/192.168.10.22\n", + "2023-10-08 14:40:34,983: SwitchPort 62:da:0d:de:eb:27 connected to Link 62:da:0d:de:eb:27<-->b8:2b:a3:f0:18:b9/192.168.10.110\n", + "2023-10-08 14:40:34,985: Link 62:da:0d:de:eb:27<-->b8:2b:a3:f0:18:b9/192.168.10.110 up\n", + "2023-10-08 14:40:34,986: NIC b8:2b:a3:f0:18:b9/192.168.10.110 connected to Link 62:da:0d:de:eb:27<-->b8:2b:a3:f0:18:b9/192.168.10.110\n", + "2023-10-08 14:40:34,987: Link 62:da:0d:de:eb:27<-->b8:2b:a3:f0:18:b9/192.168.10.110 up\n", + "2023-10-08 14:40:34,988: Added link 9469edcd-6b36-4333-b948-3eeccf24abcb to connect 62:da:0d:de:eb:27 and b8:2b:a3:f0:18:b9/192.168.10.110\n" ] }, { @@ -93,7 +153,9 @@ { "data": { "text/plain": [ - "[]" + "[,\n", + " ,\n", + " ]" ] }, "execution_count": 7, @@ -118,7 +180,33 @@ "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-10-08 14:40:35,046: Stepping primaite session. Step counter: 0\n", + "2023-10-08 14:40:35,047: Sending simulation state to agent client_1_green_user\n", + "2023-10-08 14:40:35,049: Getting agent action\n", + "2023-10-08 14:40:35,050: Formatting agent action DONOTHING\n", + "2023-10-08 14:40:35,051: Sending request to simulation: ['do_nothing']\n", + "2023-10-08 14:40:35,052: Sending simulation state to agent client_1_data_manipulation_red_bot\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'observe'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/cade/repos/PrimAITE/sandbox.ipynb Cell 10\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m sess\u001b[39m.\u001b[39;49mstep()\n", + "File \u001b[0;32m~/repos/PrimAITE/src/primaite/game/session.py:75\u001b[0m, in \u001b[0;36mPrimaiteSession.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 72\u001b[0m sim_state \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msimulation\u001b[39m.\u001b[39mdescribe_state()\n\u001b[1;32m 74\u001b[0m \u001b[39m# 6. each agent takes most recent state and converts it to CAOS observation\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m agent_obs \u001b[39m=\u001b[39m agent\u001b[39m.\u001b[39;49mconvert_state_to_obs(sim_state)\n\u001b[1;32m 77\u001b[0m \u001b[39m# 7. meanwhile each agent also takes state and calculates reward\u001b[39;00m\n\u001b[1;32m 78\u001b[0m agent_reward \u001b[39m=\u001b[39m agent\u001b[39m.\u001b[39mcalculate_reward_from_state(sim_state)\n", + "File \u001b[0;32m~/repos/PrimAITE/src/primaite/game/agent/interface.py:40\u001b[0m, in \u001b[0;36mAbstractAgent.convert_state_to_obs\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mconvert_state_to_obs\u001b[39m(\u001b[39mself\u001b[39m, state: Dict) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m ObsType:\n\u001b[1;32m 36\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[39m state : dict state directly from simulation.describe_state\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[39m output : dict state according to CAOS.\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobservation_space\u001b[39m.\u001b[39;49mobserve(state)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'observe'" + ] + } + ], "source": [ "sess.step()" ] diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 4fd52d96..6083db6f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -18,10 +18,12 @@ class AbstractAgent(ABC): def __init__( self, + agent_name: Optional[str], action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], ) -> None: + self.agent_name:str = agent_name or "unnamed_agent" self.action_space: Optional[ActionManager] = action_space self.observation_space: Optional[ObservationSpace] = observation_space self.reward_function: Optional[RewardFunction] = reward_function diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index f919a723..21f623fd 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Hashable, List, Optional +from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING from gym import spaces from pydantic import BaseModel +from primaite.game.session import PrimaiteSession from primaite.simulator.sim_container import Simulation +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession NOT_PRESENT_IN_STATE = object() """ @@ -53,6 +56,15 @@ class AbstractObservation(ABC): """Subclasses must define the shape that they expect""" ... + @abstractmethod + @classmethod + def from_config(cls, config:Dict, session:"PrimaiteSession"): + """Create this observation space component form a serialised format. + + The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, + a subclass of this class may need to translate from a 'reference' to a UUID. + """ + class FileObservation(AbstractObservation): def __init__(self, where: Optional[List[str]] = None) -> None: @@ -84,6 +96,10 @@ class FileObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict({"health_status": spaces.Discrete(6)}) + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where=None): + return cls(where=parent_where+["files", config["file_name"]]) + class ServiceObservation(AbstractObservation): default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} @@ -115,6 +131,11 @@ class ServiceObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + @classmethod + def from_config(cls, config: Dict, session: PrimaiteSession, parent_where:Optional[List[str]]=None): + return cls(where=parent_where+["services",session.ref_map_services[config['service_ref']]]) + + class LinkObservation(AbstractObservation): default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}} @@ -154,6 +175,10 @@ class LinkObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})}) + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession"): + return cls(where=['network','links', session.ref_map_links[config['link_ref']]]) + class FolderObservation(AbstractObservation): def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None: @@ -209,6 +234,15 @@ class FolderObservation(AbstractObservation): } ) + @classmethod + def from_config(cls, config: Dict, session: PrimaiteSession, parent_where:Optional[List[str]]): + where = parent_where + ["folders", config['folder_name']] + + file_configs = config["files"] + files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs] + + return cls(where=where,files=files) + class NicObservation(AbstractObservation): default_observation: spaces.Space = {"nic_status": 0} @@ -230,6 +264,10 @@ class NicObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict({"nic_status": spaces.Discrete(3)}) + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]): + return cls(where=parent_where + ["NICs", config["nic_uuid"]]) + class NodeObservation(AbstractObservation): def __init__( @@ -310,6 +348,25 @@ class NodeObservation(AbstractObservation): return spaces.Dict(space_shape) + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]= None): + node_uuid = session.ref_map_nodes[config['node_ref']] + if parent_where is None: + where = ["network", "nodes", node_uuid] + else: + where = parent_where + ["nodes", node_uuid] + + svc_configs = config.get('services', {}) + services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs] + folder_configs = config.get('folders', {}) + folders = [FolderObservation.from_config(config=c,session=session, parent_where=where) for c in folder_configs] + nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys() + nic_configs = [{'nic_uuid':n for n in nic_uuids }] + nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs] + logon_status = config.get('logon_status',False) + cls(where=where, services=services, folders=folders, nics=nics, logon_status=logon_status) + return super().from_config(config, session) + class AclObservation(AbstractObservation): @@ -399,6 +456,21 @@ class AclObservation(AbstractObservation): } ) + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation": + node_ip_to_idx = {} + for node_idx, node_cfg in enumerate(config['node_order']): + n_ref = node_cfg["node_ref"] + n_obj = session.simulation.network.nodes[session.ref_map_nodes[n_ref]] + for nic_uuid, nic_obj in n_obj.nics.items(): + node_ip_to_idx[nic_obj.ip_address] = node_idx + 2 + + router_uuid = session.ref_map_nodes[config['router_node_ref']] + return cls( + node_ip_to_id=node_ip_to_idx, + ports=session.options.ports, + protocols=session.options.protocols, + where=["network", "nodes", router_uuid]) @@ -413,6 +485,10 @@ class NullObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict({}) + @classmethod + def from_config(cls, cfg:Dict) -> "NullObservation": + return cls() + class ICSObservation(NullObservation): pass @@ -463,11 +539,18 @@ class UC2BlueObservation(AbstractObservation): }) @classmethod - def from_config(cls, config:Dict, sim:Simulation): - nodes = ... - links = ... - acl = ... - ics = ... + def from_config(cls, config:Dict, sess:"PrimaiteSession"): + node_configs = config["nodes"] + nodes = [NodeObservation.from_config(n) for n in node_configs] + + link_configs = config["links"] + links = [LinkObservation.from_config(l) for l in link_configs] + + acl_config = config["acl"] + acl = AclObservation.from_config(acl_config) + + ics_config = config["ics"] + ics = ICSObservation.from_config(ics_config) new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network']) return new @@ -489,8 +572,11 @@ class UC2RedObservation(AbstractObservation): @classmethod def from_config(cls, config: Dict, sim:Simulation): + ... #TODO +class UC2GreenObservation(NullObservation): pass + class ObservationSpace: """ Manage the observations of an Actor. @@ -515,3 +601,14 @@ class ObservationSpace: @property def space(self) -> None: return self.obs.space + + @classmethod + def from_config(cls, config:Dict, session:"PrimaiteSession") -> "ObservationSpace": + if config['type'] == "UC2BlueObservation": + return cls(UC2BlueObservation(config['options'])) + elif config['type'] == "UC2RedObservation": + return cls(UC2RedObservation(config['options'])) + elif config['type'] == "UC2GreenObservation": + return cls(UC2GreenObservation(config["options"])) + else: + raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 46e834d6..f0ae05c6 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -23,6 +23,7 @@ from primaite.game.agent.observations import ( NullObservation, ServiceObservation, UC2BlueObservation, + UC2GreenObservation, UC2RedObservation, ) from primaite.game.agent.rewards import RewardFunction @@ -41,6 +42,10 @@ from primaite.simulator.system.services.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service +from primaite import getLogger + +_LOGGER = getLogger(__name__) + class PrimaiteSessionOptions(BaseModel): ports: List[str] @@ -55,13 +60,19 @@ class PrimaiteSession: self.episode_counter: int = 0 self.options: PrimaiteSessionOptions + self.ref_map_nodes: Dict[str, Node] = {} + self.ref_map_services: Dict[str, Service] = {} + self.ref_map_links: Dict[str, Link] = {} + def step(self): + _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") # currently designed with assumption that all agents act once per step in order for agent in self.agents: # 3. primaite session asks simulation to provide initial state # 4. primate session gives state to all agents # 5. primaite session asks agents to produce an action based on most recent state + _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}") sim_state = self.simulation.describe_state() # 6. each agent takes most recent state and converts it to CAOS observation @@ -75,14 +86,18 @@ class PrimaiteSession: # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of # code should live inside of the GATE agent subclass) # gets action in CAOS format + _LOGGER.debug(f"Getting agent action") agent_action, action_options = agent.get_action(agent_obs, agent_reward) # 9. CAOS action is converted into request (extra information might be needed to enrich # the request, this is what the execution definition is there for) + _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements agent_request = agent.format_request(agent_action, action_options) # 10. primaite session receives the action from the agents and asks the simulation to apply each + _LOGGER.debug(f"Sending request to simulation: {agent_request}") self.simulation.apply_action(agent_request) + _LOGGER.debug(f"Initiating simulation step {self.step_counter}") self.simulation.apply_timestep(self.step_counter) self.step_counter += 1 @@ -96,9 +111,9 @@ class PrimaiteSession: sim = sess.simulation net = sim.network - ref_map_nodes: Dict[str, Node] = {} - ref_map_services: Dict[str, Service] = {} - ref_map_links: Dict[str, Link] = {} + sess.ref_map_nodes: Dict[str, Node] = {} + sess.ref_map_services: Dict[str, Service] = {} + sess.ref_map_links: Dict[str, Link] = {} nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] @@ -304,6 +319,8 @@ class PrimaiteSession: ) elif observation_space_cfg["type"] == "UC2RedObservation": obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim) + elif observation_space_cfg["type"] == "UC2GreenObservation": + obs_space = UC2GreenObservation.from_config(observation_space_cfg.get('options',{})) else: print("observation space config not specified correctly.") obs_space = NullObservation() @@ -334,12 +351,15 @@ class PrimaiteSession: # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": - new_agent = RandomAgent(action_space=action_space, observation_space=obs_space, reward_function=rew_function) + # TODO: implement non-random agents and fix this parsing + new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) sess.agents.append(new_agent) elif agent_type == "GATERLAgent": - ... + new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) + sess.agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": - ... + new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) + sess.agents.append(new_agent) else: print("agent type not found")