Merged PR 228: Fix annoying Ray issues

## Summary
Fix Ray episode issues. Allow Ray to fully manage its environment (the environment wrappers now build the game from the config themselves) - this fixes the reward being permanently zero.

## Test process
Running notebooks and primaite sessions.

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2095
This commit is contained in:
Marek Wolan
2023-12-04 10:10:10 +00:00
16 changed files with 199 additions and 213 deletions

1
.gitignore vendored
View File

@@ -155,3 +155,4 @@ simulation_output/
benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py

View File

@@ -655,8 +655,8 @@ simulation:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
services:

View File

@@ -1,14 +1,10 @@
training_config:
rl_framework: RLLIB_single_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 256
deterministic_eval: false
n_agents: 1
rl_framework: RLLIB_multi_agent
# rl_framework: SB3
n_agents: 2
agent_references:
- defender
- defender_1
- defender_2
io_settings:
save_checkpoints: true
@@ -36,31 +32,26 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
applications:
- application_ref: client_2_web_browser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -69,38 +60,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DELETE"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -110,11 +83,12 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender1
- ref: defender_1
team: BLUE
type: ProxyAgent
@@ -544,7 +518,9 @@ agents:
agent_settings:
# ...
- ref: defender2
- ref: defender_2
team: BLUE
type: ProxyAgent
@@ -992,17 +968,25 @@ simulation:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
@@ -1039,7 +1023,7 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
@@ -1060,6 +1044,10 @@ simulation:
services:
- ref: database_service
type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- ref: backup_server
type: server
@@ -1070,7 +1058,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server
@@ -1091,9 +1079,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1107,10 +1101,14 @@ simulation:
applications:
- ref: client_2_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1

View File

@@ -238,7 +238,8 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
self.current_reward: float = 0.0
self.total_reward: float = 0.0
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.

View File

@@ -125,6 +125,7 @@ class PrimaiteGame:
for agent in self.agents:
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
@@ -155,6 +156,8 @@ class PrimaiteGame:
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(episode=self.episode_counter)
for agent in self.agents:
agent.reward_function.total_reward = 0.0
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -240,7 +243,7 @@ class PrimaiteGame:
position=r_num,
)
else:
print("invalid node type")
_LOGGER.warning(f"invalid node type {n_type} in config")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
@@ -256,12 +259,12 @@ class PrimaiteGame:
"FTPServer": FTPServer,
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service.uuid
else:
print(f"service type not found {service_type}")
_LOGGER.warning(f"service type not found {service_type}")
# service-dependent options
if service_type == "DatabaseClient":
if "options" in service_cfg:
@@ -295,7 +298,7 @@ class PrimaiteGame:
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
print(f"application type not found {application_type}")
_LOGGER.warning(f"application type not found {application_type}")
if application_type == "DataManipulationBot":
if "options" in application_cfg:
@@ -416,7 +419,7 @@ class PrimaiteGame:
)
game.agents.append(new_agent)
else:
print("agent type not found")
_LOGGER.warning(f"agent type {agent_type} not found")
game.simulation.set_original_state()

View File

@@ -1,5 +1,21 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a multi-agent system using RLlib\n",
"\n",
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First, import packages and read our config file."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -8,75 +24,56 @@
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/example_config_2_rl_agents.yaml', 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"env_config = {\"game\":game}\n",
"ray.init(local_mode=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create a Ray algorithm config which accepts our two agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
" .rollouts(num_rollout_workers=0)\n",
" .multi_agent(\n",
" policies={agent.agent_name for agent in game.rl_agents},\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training\n",
"This example will save outputs to a default Ray directory and use mostly default settings."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -86,21 +83,11 @@
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"training_iteration\": 128},\n",
" checkpoint_config=air.CheckpointConfig(\n",
" checkpoint_frequency=10,\n",
" ),\n",
" stop={\"timesteps_total\": 512},\n",
" ),\n",
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -1,5 +1,13 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a single-agent system using RLlib\n",
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -10,19 +18,25 @@
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteRayEnv\n",
"from ray.rllib.algorithms import ppo\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
"ray.init(local_mode=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create a Ray algorithm config and pass it our environment config."
]
},
{
@@ -31,7 +45,21 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteRayEnv({\"game\":game})"
"env_config = {\"cfg\":cfg}\n",
"\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training"
]
},
{
@@ -40,61 +68,13 @@
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray.rllib.algorithms import ppo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"game\":game}\n",
"config = {\n",
" \"env\" : PrimaiteRayEnv,\n",
" \"env_config\" : env_config,\n",
" \"disable_env_checking\": True,\n",
" \"num_rollout_workers\": 0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo = ppo.PPO(config=config)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(5):\n",
" result = algo.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo.save(\"temp/deleteme\")"
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
]
}
],

View File

@@ -37,11 +37,14 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
f"Resetting environment, episode {self.game.episode_counter}, "
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
)
self.game.reset()
state = self.game.get_sim_state()
self.game.update_agents(state)
@@ -69,14 +72,15 @@ class PrimaiteGymEnv(gymnasium.Env):
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `cfg`,
which is the configuration passed to `PrimaiteGame.from_config` to build the game.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game=env_config["game"])
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
self.env.game.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -92,14 +96,14 @@ class PrimaiteRayEnv(gymnasium.Env):
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Optional[Dict] = None) -> None:
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `cfg`,
which is the configuration passed to `PrimaiteGame.from_config` to build the game.
:type env_config: Dict
"""
self.game: PrimaiteGame = env_config["game"]
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
@@ -108,7 +112,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{name: agent.observation_manager.space for name, agent in self.agents.items()}
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
@@ -159,4 +166,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
obs = {}
for name, agent in self.agents.items():
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs

View File

@@ -12,6 +12,10 @@ from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from primaite import getLogger
_LOGGER = getLogger(__name__)
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
@@ -19,7 +23,7 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
config = {
self.config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
@@ -29,12 +33,13 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
ray.shutdown()
ray.init()
self._algo = ppo.PPO(config=config)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
for ep in range(n_episodes):
self._algo.train()
self.config["training_iterations"] = n_episodes * timesteps_per_episode
self.config["train_batch_size"] = 128
self._algo = ppo.PPO(config=self.config)
_LOGGER.info("Starting RLLIB training session")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""

View File

@@ -51,14 +51,13 @@ class SB3Policy(PolicyABC, identifier="SB3"):
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
reward_data = evaluate_policy(
_ = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
print(reward_data)
def save(self, save_path: Path) -> None:
"""

View File

@@ -62,6 +62,7 @@ class PrimaiteSession:
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes

View File

@@ -113,7 +113,7 @@ class RequestManager(BaseModel):
"""
if name in self.request_types:
msg = f"Overwriting request type {name}."
_LOGGER.warning(msg)
_LOGGER.debug(msg)
self.request_types[name] = request_type

View File

@@ -220,7 +220,7 @@ class Network(SimComponent):
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:

View File

@@ -181,13 +181,13 @@ class NIC(SimComponent):
if self.enabled:
return
if not self._connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
if not self._connected_link:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Link")
return
self.enabled = True

View File

@@ -56,7 +56,7 @@ class DatabaseService(Service):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
print("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.connections.clear()
super().reset_component_for_episode(episode)

View File

@@ -47,7 +47,6 @@ class WebServer(Service):
state["last_response_status_code"] = (
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
print(state)
return state
def __init__(self, **kwargs):