diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 7393f5a3..d8cd0099 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -31,7 +31,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index c1e2ea81..6aa54487 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -25,7 +25,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 159f5bbb..146261f9 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -113,7 +113,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - self.apply_agent_actions() + agent_actions = self.apply_agent_actions() # noqa # Advance timestep self.advance_timestep() @@ -131,12 +131,15 @@ class PrimaiteGame: def apply_agent_actions(self) -> None: """Apply all actions to simulation as requests.""" + agent_actions = {} for agent in self.agents: obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) + agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) + return agent_actions def advance_timestep(self) -> None: """Advance timestep.""" diff --git a/src/primaite/notebooks/_package_data/uc2_network.png b/src/primaite/notebooks/_package_data/uc2_network.png new file mode 100644 index 00000000..20fa43c9 Binary files /dev/null and b/src/primaite/notebooks/_package_data/uc2_network.png differ diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 3950ef10..7bcfdd29 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -1,30 +1,333 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Manipulation Scenario\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scenario\n", + "\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "\n", + "[](_package_data/uc2_network.png)\n", + "\n", + "_(click image to enlarge)_\n", + "\n", + "The red agent deletes the contents of the database. When this happens, the web app cannot fetch data and users navigating to the website get a 404 error.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network\n", + "\n", + "- The web server has:\n", + " - a web service that replies to user HTTP requests\n", + " - a database client that fetches data for the web service\n", + "- The database server has:\n", + " - a POSTGRES database service\n", + " - a database file which is accessed by the database service\n", + " - FTP client used for backing up the data to the backup_server\n", + "- The backup server has:\n", + " - a copy of the database file in a known good state\n", + " - FTP server that can send the backed up file back to the database server\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Green agent\n", + "\n", + "The green agent is logged onto client 2. It sometimes uses the web browser on client 2 to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Red agent\n", + "\n", + "The red agent waits a bit then sends a DELETE query to the database from client 1. If the delete is successful, the database file is flagged as compromised to signal that data is not available." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Blue agent\n", + "\n", + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking client 1 from reaching the database server. This can be done by removing client 1's network connection or adding ACL rules on the router to stop the packets from arriving." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reinforcement learning details" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scripted agents:\n", + "### Red\n", + "The red agent sits on client 1 and uses an application called DataManipulationBot whose sole purpose is to send a DELETE query to the database.\n", + "The red agent can choose one of two action each timestep:\n", + "1. do nothing\n", + "2. execute the data manipulation application\n", + "The schedule for selecting when to execute the application is controlled by three parameters:\n", + "- start time\n", + "- frequency\n", + "- variance\n", + "Attacks start at a random timestep between (start_time - variance) and (start_time + variance). After each attack, another is attempted after a random delay between (frequency - variance) and (frequency + variance) timesteps.\n", + "\n", + "The data manipulation app itself has an element of randomness because the attack has a probability of success. The default is 0.8 to succeed with the port scan step and 0.8 to succeed with the attack itself.\n", + "Upon a successful attack, the database file becomes corrupted which incurs a negative reward for the RL defender.\n", + "\n", + "The red agent does not use information about the state of the network to decide its action.\n", + "\n", + "### Green\n", + "The green agent sits on client 2 and uses the web browser application to send requests to the web server. The schedule of the green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "\n", + "When the green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Observation Space\n", + "\n", + "The blue agent's observation space is structured as nested dictionary with the following information:\n", + "```\n", + "\n", + "- NODES\n", + " - \n", + " - SERVICES\n", + " - \n", + " - operating_status\n", + " - health_status\n", + " - FOLDERS\n", + " - \n", + " - health_status\n", + " - FILES\n", + " - \n", + " - health_status\n", + " - NICS\n", + " - \n", + " - nic_status\n", + " - operating_status\n", + "- LINKS\n", + " - \n", + " - PROTOCOLS\n", + " - ALL\n", + " - load\n", + "- ACL\n", + " - \n", + " - position\n", + " - permission\n", + " - source_node_id\n", + " - source_port\n", + " - dest_node_id\n", + " - dest_port\n", + " - protocol\n", + "- ICS\n", + "```\n", + "\n", + "### Mappings\n", + "\n", + "The dict keys for `node_id` are in the following order:\n", + "|node_id|node name|\n", + "|--|--|\n", + "|1|domain_controller|\n", + "|2|web_server|\n", + "|3|database_server|\n", + "|4|backup_server|\n", + "|5|security_suite|\n", + "|6|client_1|\n", + "|7|client_2|\n", + "\n", + "Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "Folder 1 on node 3 corresponds to the database folder. File 1 in that folder corresponds to the database storage file. Other files and folders are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "The dict keys for `link_id` are in the following order:\n", + "|link_id|endpoint_a|endpoint_b|\n", + "|--|--|--|\n", + "|1|router_1|switch_1|\n", + "|1|router_1|switch_2|\n", + "|1|switch_1|domain_controller|\n", + "|1|switch_1|web_server|\n", + "|1|switch_1|database_server|\n", + "|1|switch_1|backup_server|\n", + "|1|switch_1|security_suite|\n", + "|1|switch_2|client_1|\n", + "|1|switch_2|client_2|\n", + "|1|switch_2|security_suite|\n", + "\n", + "The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n", + "\n", + "Most nodes have only 1 nic, so the observation for those is placed at NIC index 1 in the observation space. Only the security suite has 2 NICs, the second NIC in the observation space is the one that connects the security suite with swtich_2.\n", + "\n", + "The meaning of the services' operating_state is:\n", + "|operating_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|RUNNING|\n", + "|2|STOPPED|\n", + "|3|PAUSED|\n", + "|4|DISABLED|\n", + "|5|INSTALLING|\n", + "|6|RESTARTING|\n", + "\n", + "The meaning of the services' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|PATCHING|\n", + "|3|COMPROMISED|\n", + "|4|OVERWHELMED|\n", + "\n", + "The meaning of the files' and folders' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|COMPROMISED|\n", + "|3|CORRUPT|\n", + "|4|RESTORING|\n", + "|5|REPAIRING|\n", + "\n", + "The meaning of the NICs' operating_status is:\n", + "|operating_status|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ENABLED|\n", + "|2|DISABLED|\n", + "\n", + "Link load has the following meaning:\n", + "|load|percent utilisation|\n", + "|--|--|\n", + "|0|exactly 0%|\n", + "|1|0-11%|\n", + "|2|11-22%|\n", + "|3|22-33%|\n", + "|4|33-44%|\n", + "|5|44-55%|\n", + "|6|55-66%|\n", + "|7|66-77%|\n", + "|8|77-88%|\n", + "|9|88-99%|\n", + "|10|exactly 100%|\n", + "\n", + "ACL permission has the following meaning:\n", + "|permission|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALLOW|\n", + "|2|DENY|\n", + "\n", + "ACL source / destination node ids actually correspond to IP addresses (since ACLs work with IP addresses)\n", + "|source / dest node id|ip_address|label|\n", + "|--|--|--|\n", + "|0| | UNUSED|\n", + "|1| |ALL addresses|\n", + "|2| 192.168.1.10 | domain_controller|\n", + "|3| 192.168.1.12 | web_server \n", + "|4| 192.168.1.14 | database_server|\n", + "|5| 192.168.1.16 | backup_server|\n", + "|6| 192.168.1.110 | security_suite (eth-1)|\n", + "|7| 192.168.10.21 | client_1|\n", + "|8| 192.168.10.22 | client_2|\n", + "|9| 192.168.10.110| security_suite (eth-2)|\n", + "\n", + "ACL source / destination port ids have the following encoding:\n", + "|port id|port number| port use |\n", + "|--|--|--|\n", + "|0||UNUSED|\n", + "|1||ALL|\n", + "|2|219|ARP|\n", + "|3|53|DNS|\n", + "|4|80|HTTP|\n", + "|5|5432|POSTGRES_SERVER|\n", + "\n", + "ACL protocol ids have the following encoding:\n", + "|protocol id|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALL|\n", + "|2|ICMP|\n", + "|3|TCP|\n", + "|4|UDP|\n", + "\n", + "protocol" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Action Space\n", + "\n", + "The blue agent chooses from a list of 54 pre-defined actions. The full list is defined in the `action_map` in the config. The most important ones are explained here:\n", + "\n", + "- `0`: Do nothing\n", + "- `1`: Scan the web service - this refreshes the health status in the observation space\n", + "- `9`: Scan the database file - this refreshes the health status of the database file\n", + "- `13`: Patch the database service - This triggers the database to restore data from the backup server\n", + "- `19`: Shut down client 1\n", + "- `22`: Block outgoing traffic from client 1\n", + "- `26`: Block TCP traffic from client 1 to the database node\n", + "- `28-37`: Remove ACL rules 1-10\n", + "- `42`: Disconnect client 1 from the network\n", + "\n", + "The other actions will either have no effect or will negatively impact the network, so the blue agent should avoid taking other actions, and learn about these actions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reward Function\n", + "\n", + "The blue agent's reward is calculated using two measures:\n", + "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", + "2. Whether the green agent's most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "These two components are averaged to get the final reward.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demonstration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, load the required modules" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" - ] - } - ], + "outputs": [], "source": [ - "from primaite.session.session import PrimaiteSession\n", - "from primaite.game.game import PrimaiteGame\n", - "from primaite.config.load import example_config_path\n", - "\n", - "from primaite.simulator.system.services.database.database_service import DatabaseService\n", - "\n", - "import yaml" + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -36,61 +339,181 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "installing DNSServer on node domain_controller\n", - "installing DatabaseClient on node web_server\n", - "installing WebServer on node web_server\n", - "installing DatabaseService on node database_server\n", - "installing FTPClient on node database_server\n", - "installing FTPServer on node backup_server\n", - "installing DNSClient on node client_1\n", - "installing DNSClient on node client_2\n" + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2024-01-25 11:19:29,199\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2024-01-25 11:19:31,924\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], "source": [ + "# Imports\n", + "from primaite.config.load import example_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from pprint import pprint\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instantiate the environment. We also disable the agent observation flattening.\n", "\n", - "with open(example_config_path(),'r') as cfgfile:\n", - " cfg = yaml.safe_load(cfgfile)\n", - "game = PrimaiteGame.from_config(cfg)\n", - "net = game.simulation.network\n", - "database_server = net.get_node_by_hostname('database_server')\n", - "web_server = net.get_node_by_hostname('web_server')\n", - "client_1 = net.get_node_by_hostname('client_1')\n", - "\n", - "db_service = database_server.software_manager.software[\"DatabaseService\"]\n", - "db_client = web_server.software_manager.software[\"DatabaseClient\"]\n", - "# db_client.run()\n", - "db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n", - "db_manipulation_bot.port_scan_p_of_success=1.0\n", - "db_manipulation_bot.data_manipulation_p_of_success=1.0\n" + "This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 1}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ - "db_client.run()" + "# create the env\n", + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "game = PrimaiteGame.from_config(cfg)\n", + "env = PrimaiteGymEnv(game = game)\n", + "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", + "env.agent.flatten_obs = False\n", + "obs, info = env.reset()\n", + "print('env created successfully')\n", + "pprint(obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The red agent will start attacking at some point between step 20 and 30.\n", + "\n", + "The red agent has a random chance of failing its attack, so you may need run the following cell multiple times until the reward goes from 1.0 to -1.0." ] }, { @@ -99,18 +522,53 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1, Red action: DONOTHING, Blue reward:1.0\n", + "step: 2, Red action: DONOTHING, Blue reward:1.0\n", + "step: 3, Red action: DONOTHING, Blue reward:1.0\n", + "step: 4, Red action: DONOTHING, Blue reward:1.0\n", + "step: 5, Red action: DONOTHING, Blue reward:1.0\n", + "step: 6, Red action: DONOTHING, Blue reward:1.0\n", + "step: 7, Red action: DONOTHING, Blue reward:1.0\n", + "step: 8, Red action: DONOTHING, Blue reward:1.0\n", + "step: 9, Red action: DONOTHING, Blue reward:1.0\n", + "step: 10, Red action: DONOTHING, Blue reward:1.0\n", + "step: 11, Red action: DONOTHING, Blue reward:1.0\n", + "step: 12, Red action: DONOTHING, Blue reward:1.0\n", + "step: 13, Red action: DONOTHING, Blue reward:1.0\n", + "step: 14, Red action: DONOTHING, Blue reward:1.0\n", + "step: 15, Red action: DONOTHING, Blue reward:1.0\n", + "step: 16, Red action: DONOTHING, Blue reward:1.0\n", + "step: 17, Red action: DONOTHING, Blue reward:1.0\n", + "step: 18, Red action: DONOTHING, Blue reward:1.0\n", + "step: 19, Red action: DONOTHING, Blue reward:1.0\n", + "step: 20, Red action: DONOTHING, Blue reward:1.0\n", + "step: 21, Red action: DONOTHING, Blue reward:1.0\n", + "step: 22, Red action: DONOTHING, Blue reward:1.0\n", + "step: 23, Red action: DONOTHING, Blue reward:1.0\n", + "step: 24, Red action: DONOTHING, Blue reward:1.0\n", + "step: 25, Red action: DONOTHING, Blue reward:1.0\n", + "step: 26, Red action: DONOTHING, Blue reward:1.0\n", + "step: 27, Red action: NODE_APPLICATION_EXECUTE, Blue reward:0.0\n", + "step: 28, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 29, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 30, Red action: DONOTHING, Blue reward:-1.0\n" + ] } ], "source": [ - "db_service.backup_database()" + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the reward is -1, let's have a look at blue agent's observation." ] }, { @@ -119,27 +577,110 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The true statuses of the database file and webapp are not updated. The blue agent needs to perform a scan to see that they have degraded." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 3, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 2}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] + } + ], "source": [ - "db_manipulation_bot.run()" + "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", + "obs, reward, terminated, truncated, info = env.step(1) # scan webapp service\n", + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n", + "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can now patch the database to restore the file to a good health status." ] }, { @@ -148,130 +689,221 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db_client.query(\"SELECT\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db_service.restore_backup()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db_client.query(\"SELECT\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_manipulation_bot.run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "client_1.ping(database_server.ethernet_port[1].ip_address)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from pydantic import validate_call, BaseModel" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "class A(BaseModel):\n", - " x:int\n", - "\n", - " @validate_call\n", - " def increase_x(self, by:int) -> None:\n", - " self.x += 1" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "my_a = A(x=3)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "ename": "ValidationError", - "evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n", - "File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n", - "\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 33\n", + "Red action: DONOTHING\n", + "Green action: DONOTHING\n", + "Blue reward:-1.0\n" ] } ], "source": [ - "my_a.increase_x(3.2)" + "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n", + "\n", + "The reward will be 0 as soon as the file finishes restoring. Then, the reward will increase to 1 when the green agent makes a request. (Because the webapp access part of the reward does not update until a successful request is made.)\n", + "\n", + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 44\n", + "Red action: DONOTHING\n", + "Green action: NODE_APPLICATION_EXECUTE\n", + "Blue reward:-1.0\n" + ] + } + ], + "source": [ + "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can prevent attacks by implementing an ACL rule to stop client_1 from sending POSTGRES traffic to the database. (Let's also patch the database file to get the reward back up.)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 107, Red action: DONOTHING, Blue reward:1.0\n", + "step: 108, Red action: DONOTHING, Blue reward:1.0\n", + "step: 109, Red action: DONOTHING, Blue reward:1.0\n", + "step: 110, Red action: DONOTHING, Blue reward:1.0\n", + "step: 111, Red action: DONOTHING, Blue reward:1.0\n", + "step: 112, Red action: DONOTHING, Blue reward:1.0\n", + "step: 113, Red action: DONOTHING, Blue reward:1.0\n", + "step: 114, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.0\n", + "step: 115, Red action: DONOTHING, Blue reward:1.0\n", + "step: 116, Red action: DONOTHING, Blue reward:1.0\n", + "step: 117, Red action: DONOTHING, Blue reward:1.0\n", + "step: 118, Red action: DONOTHING, Blue reward:1.0\n", + "step: 119, Red action: DONOTHING, Blue reward:1.0\n", + "step: 120, Red action: DONOTHING, Blue reward:1.0\n", + "step: 121, Red action: DONOTHING, Blue reward:1.0\n", + "step: 122, Red action: DONOTHING, Blue reward:1.0\n", + "step: 123, Red action: DONOTHING, Blue reward:1.0\n", + "step: 124, Red action: DONOTHING, Blue reward:1.0\n", + "step: 125, Red action: DONOTHING, Blue reward:1.0\n", + "step: 126, Red action: DONOTHING, Blue reward:1.0\n", + "step: 127, Red action: DONOTHING, Blue reward:1.0\n", + "step: 128, Red action: DONOTHING, Blue reward:1.0\n", + "step: 129, Red action: DONOTHING, Blue reward:1.0\n", + "step: 130, Red action: DONOTHING, Blue reward:1.0\n", + "step: 131, Red action: DONOTHING, Blue reward:1.0\n", + "step: 132, Red action: DONOTHING, Blue reward:1.0\n", + "step: 133, Red action: DONOTHING, Blue reward:1.0\n", + "step: 134, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.0\n", + "step: 135, Red action: DONOTHING, Blue reward:1.0\n", + "step: 136, Red action: DONOTHING, Blue reward:1.0\n" + ] + } + ], + "source": [ + "env.step(13) # Patch the database\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "env.step(26) # Block client 1\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, even though the red agent executes an attack, the reward stays at 1.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's also have a look at the ACL observation to verify our new ACL rule at position 5." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{1: {'position': 0,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 2: {'position': 1,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 3: {'position': 2,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 4: {'position': 3,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 5: {'position': 4,\n", + " 'permission': 2,\n", + " 'source_node_id': 7,\n", + " 'source_port': 1,\n", + " 'dest_node_id': 4,\n", + " 'dest_port': 1,\n", + " 'protocol': 3},\n", + " 6: {'position': 5,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 7: {'position': 6,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 8: {'position': 7,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 9: {'position': 8,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 10: {'position': 9,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0}}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs['ACL']" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 6701f183..a3831bc1 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() @@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env): reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {} + info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info @@ -172,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -186,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {} + infos = {"agent_actions": agent_actions} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 9070f246..e5458670 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -19,7 +19,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index e67f6606..767279ce 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 220ca21e..6290fa53 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -29,7 +29,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index d7e94cb6..89b88475 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -27,7 +27,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index b89349c0..b9fa1216 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: