Merged PR 296: Add a notebook to customise red agent behaviour
## Summary just run the notebook and see if the explanation makes sense. I also renamed some stuff to make it more user friendly Related work items: #2343
This commit is contained in:
@@ -109,6 +109,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/game_layer
|
||||
source/config
|
||||
source/environment
|
||||
source/customising_scenarios
|
||||
|
||||
.. toctree::
|
||||
:caption: Developer information:
|
||||
|
||||
4
docs/source/customising_scenarios.rst
Normal file
4
docs/source/customising_scenarios.rst
Normal file
@@ -0,0 +1,4 @@
|
||||
Customising Agents
|
||||
******************
|
||||
|
||||
For an example of how to customise red agent behaviour in the Data Manipulation scenario, please refer to the notebook ``Data-Manipulation-Customising-Red-Agent.ipynb``.
|
||||
@@ -127,10 +127,10 @@ def session(
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
config = data_manipulation_config_path()
|
||||
print(config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -134,9 +134,6 @@ agents:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
|
||||
@@ -30,7 +30,7 @@ def load(file_path: Union[str, Path]) -> Dict:
|
||||
return config
|
||||
|
||||
|
||||
def example_config_path() -> Path:
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the example config.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import example_config_path, load
|
||||
from primaite.config.load import data_manipulation_config_path, load
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
@@ -42,6 +42,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.config:
|
||||
args.config = example_config_path()
|
||||
args.config = data_manipulation_config_path()
|
||||
|
||||
run(args.config)
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Customising red agents\n",
|
||||
"\n",
|
||||
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
|
||||
"\n",
|
||||
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
|
||||
"\n",
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def make_cfg_have_flat_obs(cfg):\n",
|
||||
" for agent in cfg['agents']:\n",
|
||||
" if agent['type'] == \"ProxyAgent\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" make_cfg_have_flat_obs(cfg)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"obs, info = env.reset()\n",
|
||||
"print('env created successfully')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
"    # parse the info dict from the step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info[0]\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"By default, the red agent can start on client 1 or client 2. It starts its attack on a random step between 20 and 30, and it repeats its attack every 15-25 steps.\n",
|
||||
"\n",
|
||||
"It also has a 20% chance to fail to perform the port scan, and a 20% chance to fail launching the SQL attack. However, it will continue where it left off after a failed step, i.e. if lucky, it can perform the port scan and SQL attack on the first try. If the port scan works but the SQL attack fails the first time it tries to attack, the next time it will not need to port scan again; it can go straight to trying the SQL attack again."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(35):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Since the agent does nothing most of the time, let's only print the steps where it performs an attack."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"for step in range(100):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if red.startswith(\"ATTACK\"):\n",
|
||||
" print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Red Configuration\n",
|
||||
"\n",
|
||||
"There are two important parts of the YAML config for varying red agent behaviour.\n",
|
||||
"\n",
|
||||
"### Red agent settings\n",
|
||||
"Here is an annotated config for the red agent in the data manipulation scenario.\n",
|
||||
"```yaml\n",
|
||||
" - ref: data_manipulation_attacker # name of agent\n",
|
||||
" team: RED # not used, just for human reference\n",
|
||||
" type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n",
|
||||
"\n",
|
||||
" # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n",
|
||||
" observation_space:\n",
|
||||
" type: UC2RedObservation\n",
|
||||
" options:\n",
|
||||
" nodes: {}\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
"\n",
|
||||
"    # The agent has two action choices: either do nothing, or execute a pre-scripted attack by using the DataManipulationBot application.\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
"\n",
|
||||
"    # The agent has access to the DataManipulationBot on clients 1 and 2.\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1 # The network should have a node called client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n",
|
||||
" - node_name: client_2 # The network should have a node called client_2\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n",
|
||||
"\n",
|
||||
" # not important\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 1\n",
|
||||
"\n",
|
||||
" # red agent does not need a reward function\n",
|
||||
" reward_function:\n",
|
||||
" reward_components:\n",
|
||||
" - type: DUMMY\n",
|
||||
"\n",
|
||||
"    # These settings are passed to the RedDatabaseCorruptingAgent init method; they dictate the schedule of attacks\n",
|
||||
" agent_settings:\n",
|
||||
" start_settings:\n",
|
||||
" start_step: 25 # first attack at step 25\n",
|
||||
" frequency: 20 # attacks will happen every 20 steps (on average)\n",
|
||||
" variance: 5 # the timing of attacks will vary by up to 5 steps earlier or later\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Malicious application settings\n",
|
||||
"The red agent uses an application called `DataManipulationBot` which leverages a node's `DatabaseClient` to send a malicious SQL query to the database server. Here's an annotated example of how this is configured in the yaml *(with irrelevant config items omitted)*:\n",
|
||||
"```yaml\n",
|
||||
"simulation:\n",
|
||||
" network:\n",
|
||||
" nodes:\n",
|
||||
" - ref: client_1\n",
|
||||
" hostname: client_1\n",
|
||||
" type: computer\n",
|
||||
" ip_address: 192.168.10.21\n",
|
||||
" subnet_mask: 255.255.255.0\n",
|
||||
" default_gateway: 192.168.10.1\n",
|
||||
" \n",
|
||||
" # \n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 0.8 # Probability that port scan is successful\n",
|
||||
" data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n",
|
||||
" payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n",
|
||||
" server_ip: 192.168.1.14 # IP address of server hosting the database\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14 # IP address of server hosting the database\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Editing red agent settings\n",
|
||||
"\n",
|
||||
"### Removing randomness from attack timing\n",
|
||||
"\n",
|
||||
"We can make the attacks happen at completely predictable intervals if we edit the red agent's settings to set variance to 0."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
"start_settings:\n",
|
||||
" start_step: 25\n",
|
||||
" frequency: 20\n",
|
||||
" variance: 0\n",
|
||||
"\"\"\")\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == \"data_manipulation_attacker\":\n",
|
||||
" agent['agent_settings'] = change\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env.reset()\n",
|
||||
"for step in range(100):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if red.startswith(\"ATTACK\"):\n",
|
||||
" print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Making the start node always the same\n",
|
||||
"\n",
|
||||
"Normally, the agent randomly chooses between the nodes in its action space to send attacks from:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Open the config without changing anything\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env.reset()\n",
|
||||
"for ep in range(12):\n",
|
||||
" env.reset()\n",
|
||||
" for step in range(31):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if red.startswith(\"ATTACK\"):\n",
|
||||
" print(f\"Episode: {ep:2}, step: {step_num:3}, Red action: {friendly_output_red_action(info)}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can make the agent always start on a node of our choice by letting that be the only node in the agent's action space."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
"action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 1\n",
|
||||
"\"\"\")\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == \"data_manipulation_attacker\":\n",
|
||||
" agent.update(change)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env.reset()\n",
|
||||
"for ep in range(12):\n",
|
||||
" env.reset()\n",
|
||||
" for step in range(31):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if red.startswith(\"ATTACK\"):\n",
|
||||
" print(f\"Episode: {ep:2}, step: {step_num:3}, Red action: {friendly_output_red_action(info)}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Make the attack less likely to succeed.\n",
|
||||
"\n",
|
||||
"We can change the success probabilities within the data manipulation bot application. When the attack succeeds, the reward goes down.\n",
|
||||
"\n",
|
||||
"Setting the probabilities to 1.0 means the attack always succeeds - the reward will always drop\n",
|
||||
"\n",
|
||||
"Setting the probabilities to 0.0 means the attack always fails - the reward will never drop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Make attack always succeed.\n",
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 1.0\n",
|
||||
" data_manipulation_p_of_success: 1.0\n",
|
||||
" payload: \"DELETE\"\n",
|
||||
" server_ip: 192.168.1.14\n",
|
||||
" - ref: client_1_web_browser\n",
|
||||
" type: WebBrowser\n",
|
||||
" options:\n",
|
||||
" target_url: http://arcd.com/users/\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14\n",
|
||||
"\"\"\")\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" cfg['simulation']['network']\n",
|
||||
" for node in cfg['simulation']['network']['nodes']:\n",
|
||||
" if node['ref'] in ['client_1', 'client_2']:\n",
|
||||
" node['applications'] = change['applications']\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env.reset()\n",
|
||||
"for ep in range(5):\n",
|
||||
" env.reset()\n",
|
||||
" for step in range(36):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if step_num == 35:\n",
|
||||
" print(f\"Episode: {ep:2}, step: {step_num:3}, Reward: {reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Make attack always fail.\n",
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 0.0\n",
|
||||
" data_manipulation_p_of_success: 0.0\n",
|
||||
" payload: \"DELETE\"\n",
|
||||
" server_ip: 192.168.1.14\n",
|
||||
" - ref: client_1_web_browser\n",
|
||||
" type: WebBrowser\n",
|
||||
" options:\n",
|
||||
" target_url: http://arcd.com/users/\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14\n",
|
||||
"\"\"\")\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" cfg['simulation']['network']\n",
|
||||
" for node in cfg['simulation']['network']['nodes']:\n",
|
||||
" if node['ref'] in ['client_1', 'client_2']:\n",
|
||||
" node['applications'] = change['applications']\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env.reset()\n",
|
||||
"for ep in range(5):\n",
|
||||
" env.reset()\n",
|
||||
" for step in range(36):\n",
|
||||
" step_num = env.game.step_counter\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" red = friendly_output_red_action(info)\n",
|
||||
" if step_num == 35:\n",
|
||||
" print(f\"Episode: {ep:2}, step: {step_num:3}, Reward: {reward:.2f}\" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -371,7 +371,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
@@ -394,7 +394,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create the env\n",
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" # set success probability to 1.0 to avoid rerunning cells.\n",
|
||||
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
|
||||
@@ -16,7 +16,7 @@
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
@@ -26,7 +26,7 @@
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"ray.init(local_mode=True)\n"
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.config.load import example_config_path"
|
||||
"from primaite.config.load import data_manipulation_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -26,7 +26,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -6,10 +6,13 @@ import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
@@ -75,7 +78,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
print(
|
||||
_LOGGER.info(
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
|
||||
@@ -13,7 +13,7 @@ from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(example_config_path(), "r") as f:
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
@@ -6,7 +6,7 @@ import ray
|
||||
import yaml
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
|
||||
@@ -14,7 +14,7 @@ from primaite.session.environment import PrimaiteRayEnv
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
def test_rllib_single_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
@@ -6,14 +6,14 @@ import pytest
|
||||
import yaml
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
def test_sb3_compatibility():
|
||||
"""Test that the Gymnasium environment can be used with an SB3 agent."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
gym = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -7,7 +7,7 @@ from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ
|
||||
|
||||
def test_example_config():
|
||||
"""Test that the example config can be parsed properly."""
|
||||
game = load_config(example_config_path())
|
||||
game = load_config(data_manipulation_config_path())
|
||||
network: Network = game.simulation.network
|
||||
|
||||
assert len(network.nodes) == 10 # 10 nodes in example network
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import ProxyAgent, RandomAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
@@ -37,7 +37,7 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
|
||||
def test_example_config():
|
||||
"""Test that the example config can be parsed properly."""
|
||||
game = load_config(example_config_path())
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert len(game.agents) == 4 # red, blue and 2 green agents
|
||||
|
||||
|
||||
Reference in New Issue
Block a user