Merged PR 228: Fix annoying Ray issues
## Summary Fix ray episode issues. Allow ray to manage its environment fully - this fixes the perma-zero reward. ## Test process Running notebooks and primaite sessions. ## Checklist - [ ] PR is linked to a **work item** - [ ] **acceptance criteria** of linked ticket are met - [ ] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [ ] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2095
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -155,3 +155,4 @@ simulation_output/
|
||||
benchmark/output
|
||||
# src/primaite/notebooks/scratch.ipynb
|
||||
src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
|
||||
@@ -655,8 +655,8 @@ simulation:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.1
|
||||
data_manipulation_p_of_success: 0.1
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
services:
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
training_config:
|
||||
rl_framework: RLLIB_single_agent
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 256
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
rl_framework: RLLIB_multi_agent
|
||||
# rl_framework: SB3
|
||||
n_agents: 2
|
||||
agent_references:
|
||||
- defender
|
||||
- defender_1
|
||||
- defender_2
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
@@ -36,31 +32,26 @@ agents:
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
applications:
|
||||
- application_ref: client_2_web_browser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
max_applications_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
start_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
@@ -69,38 +60,20 @@ agents:
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
nodes: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DELETE"
|
||||
# success_rate: 80%
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
applications:
|
||||
- application_ref: data_manipulation_bot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
@@ -110,11 +83,12 @@ agents:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender1
|
||||
- ref: defender_1
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
@@ -544,7 +518,9 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
- ref: defender2
|
||||
|
||||
|
||||
- ref: defender_2
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
@@ -992,17 +968,25 @@ simulation:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
18:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
19:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
20:
|
||||
action: PERMIT
|
||||
src_port: FTP
|
||||
dst_port: FTP
|
||||
21:
|
||||
action: PERMIT
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
@@ -1039,7 +1023,7 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
@@ -1060,6 +1044,10 @@ simulation:
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- ref: database_ftp_client
|
||||
type: FTPClient
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
@@ -1070,7 +1058,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
@@ -1091,9 +1079,15 @@ simulation:
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
applications:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.1
|
||||
data_manipulation_p_of_success: 0.1
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
@@ -1107,10 +1101,14 @@ simulation:
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
|
||||
@@ -238,7 +238,8 @@ class RewardFunction:
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float
|
||||
self.current_reward: float = 0.0
|
||||
self.total_reward: float = 0.0
|
||||
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
|
||||
@@ -125,6 +125,7 @@ class PrimaiteGame:
|
||||
for agent in self.agents:
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
@@ -155,6 +156,8 @@ class PrimaiteGame:
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(episode=self.episode_counter)
|
||||
for agent in self.agents:
|
||||
agent.reward_function.total_reward = 0.0
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
@@ -240,7 +243,7 @@ class PrimaiteGame:
|
||||
position=r_num,
|
||||
)
|
||||
else:
|
||||
print("invalid node type")
|
||||
_LOGGER.warning(f"invalid node type {n_type} in config")
|
||||
if "services" in node_cfg:
|
||||
for service_cfg in node_cfg["services"]:
|
||||
new_service = None
|
||||
@@ -256,12 +259,12 @@ class PrimaiteGame:
|
||||
"FTPServer": FTPServer,
|
||||
}
|
||||
if service_type in service_types_mapping:
|
||||
print(f"installing {service_type} on node {new_node.hostname}")
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
game.ref_map_services[service_ref] = new_service.uuid
|
||||
else:
|
||||
print(f"service type not found {service_type}")
|
||||
_LOGGER.warning(f"service type not found {service_type}")
|
||||
# service-dependent options
|
||||
if service_type == "DatabaseClient":
|
||||
if "options" in service_cfg:
|
||||
@@ -295,7 +298,7 @@ class PrimaiteGame:
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
game.ref_map_applications[application_ref] = new_application.uuid
|
||||
else:
|
||||
print(f"application type not found {application_type}")
|
||||
_LOGGER.warning(f"application type not found {application_type}")
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
@@ -416,7 +419,7 @@ class PrimaiteGame:
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
else:
|
||||
print("agent type not found")
|
||||
_LOGGER.warning(f"agent type {agent_type} not found")
|
||||
|
||||
game.simulation.set_original_state()
|
||||
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Multi agent system using RLLIB\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### First, Import packages and read our config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -8,75 +24,56 @@
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# gym = PrimaiteRayEnv({\"game\":game})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ray\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ray.shutdown()\n",
|
||||
"ray.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/example_config_2_rl_agents.yaml', 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"env_config = {\"game\":game}\n",
|
||||
"ray.init(local_mode=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Create a Ray algorithm config which accepts our two agents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .multi_agent(\n",
|
||||
" policies={agent.agent_name for agent in game.rl_agents},\n",
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Set training parameters and start the training\n",
|
||||
"This example will save outputs to a default Ray directory and use mostly default settings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -86,21 +83,11 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"training_iteration\": 128},\n",
|
||||
" checkpoint_config=air.CheckpointConfig(\n",
|
||||
" checkpoint_frequency=10,\n",
|
||||
" ),\n",
|
||||
" stop={\"timesteps_total\": 512},\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Single agent system using RLLib\n",
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -10,19 +18,25 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
"ray.init(local_mode=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Create a Ray algorithm and pass it our config."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -31,7 +45,21 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteRayEnv({\"game\":game})"
|
||||
"env_config = {\"cfg\":cfg}\n",
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Set training parameters and start the training"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,61 +68,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms import ppo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ray.shutdown()\n",
|
||||
"ray.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env_config = {\"game\":game}\n",
|
||||
"config = {\n",
|
||||
" \"env\" : PrimaiteRayEnv,\n",
|
||||
" \"env_config\" : env_config,\n",
|
||||
" \"disable_env_checking\": True,\n",
|
||||
" \"num_rollout_workers\": 0,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"algo = ppo.PPO(config=config)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(5):\n",
|
||||
" result = algo.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"algo.save(\"temp/deleteme\")"
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -37,11 +37,14 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
print(
|
||||
f"Resetting environment, episode {self.game.episode_counter}, "
|
||||
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
|
||||
)
|
||||
self.game.reset()
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
@@ -69,14 +72,15 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game=env_config["game"])
|
||||
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
|
||||
self.env.game.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -92,14 +96,14 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Optional[Dict] = None) -> None:
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.game: PrimaiteGame = env_config["game"]
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
|
||||
"""Reference to the primaite game"""
|
||||
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
|
||||
"""List of all possible agents in the environment. This list should not change!"""
|
||||
@@ -108,7 +112,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{name: agent.observation_manager.space for name, agent in self.agents.items()}
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
@@ -159,4 +166,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
|
||||
obs = {}
|
||||
for name, agent in self.agents.items():
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
@@ -12,6 +12,10 @@ from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
@@ -19,7 +23,7 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
|
||||
config = {
|
||||
self.config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": {"game": session.game},
|
||||
"disable_env_checking": True,
|
||||
@@ -29,12 +33,13 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
self._algo = ppo.PPO(config=config)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
for ep in range(n_episodes):
|
||||
self._algo.train()
|
||||
self.config["training_iterations"] = n_episodes * timesteps_per_episode
|
||||
self.config["train_batch_size"] = 128
|
||||
self._algo = ppo.PPO(config=self.config)
|
||||
_LOGGER.info("Starting RLLIB training session")
|
||||
self._algo.train()
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
|
||||
@@ -51,14 +51,13 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
reward_data = evaluate_policy(
|
||||
_ = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
print(reward_data)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""
|
||||
|
||||
@@ -62,6 +62,7 @@ class PrimaiteSession:
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
print("Starting Primaite Session")
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
|
||||
@@ -113,7 +113,7 @@ class RequestManager(BaseModel):
|
||||
"""
|
||||
if name in self.request_types:
|
||||
msg = f"Overwriting request type {name}."
|
||||
_LOGGER.warning(msg)
|
||||
_LOGGER.debug(msg)
|
||||
|
||||
self.request_types[name] = request_type
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@ class Network(SimComponent):
|
||||
self._node_id_map[len(self.nodes)] = node
|
||||
node.parent = self
|
||||
self._nx_graph.add_node(node.hostname)
|
||||
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
|
||||
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
|
||||
self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager))
|
||||
|
||||
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
|
||||
|
||||
@@ -181,13 +181,13 @@ class NIC(SimComponent):
|
||||
if self.enabled:
|
||||
return
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
|
||||
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
|
||||
return
|
||||
if not self._connected_link:
|
||||
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
|
||||
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Link")
|
||||
return
|
||||
|
||||
self.enabled = True
|
||||
|
||||
@@ -56,7 +56,7 @@ class DatabaseService(Service):
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
print("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ class WebServer(Service):
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
|
||||
)
|
||||
print(state)
|
||||
return state
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user