diff --git a/docs/source/config.rst b/docs/source/config.rst index 1bea0671..52748eec 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -29,7 +29,7 @@ The environment config file consists of the following attributes: * SB3 - Stable Baselines3 * RLLIB - Ray RLlib. -* **red_agent_identifier** +* **agent_identifier** This identifies the agent to use for the session. Select from one of the following: diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 0da493b5..3068ee27 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0b1 \ No newline at end of file +2.0.0rc1 \ No newline at end of file diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 34ad0adb..812072ba 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,4 +1,5 @@ import json +import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path @@ -12,7 +13,6 @@ from primaite.config import training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite - _LOGGER = getLogger(__name__) @@ -196,50 +196,77 @@ class AgentSessionABC(ABC): pass -class DeterministicAgentSessionABC(AgentSessionABC): - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None +class HardCodedAgentSessionABC(AgentSessionABC): + def __init__(self, training_config_path, lay_down_config_path): + super().__init__(training_config_path, lay_down_config_path) + self._setup() - @abstractmethod def _setup(self): + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) + super()._setup() + 
self._can_learn = False + self._can_evaluate = True + + + def _save_checkpoint(self): pass - @abstractmethod def _get_latest_checkpoint(self): pass def learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod + def _calculate_action(self, obs): + pass + def evaluate( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): - pass + if not time_steps: + time_steps = self._training_config.num_steps + + if not episodes: + episodes = self._training_config.num_episodes + + for episode in range(episodes): + # Reset env and collect initial observation + obs = self._env.reset() + for step in range(time_steps): + # Calculate action + action = self._calculate_action(obs) + + # Perform the step + obs, reward, done, info = self._env.step(action) + + if done: + break + + # Introduce a delay between steps + time.sleep(self._training_config.time_delay / 1000) + self._env.close() @classmethod - @abstractmethod def load(cls): - pass + _LOGGER.warning("Deterministic agents cannot be loaded") - @abstractmethod def save(self): - pass + _LOGGER.warning("Deterministic agents cannot be saved") - @abstractmethod def export(self): - pass + _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py new file mode 100644 index 00000000..4ad08f6e --- /dev/null +++ b/src/primaite/agents/hardcoded_acl.py @@ -0,0 +1,376 @@ +import numpy as np + +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + get_node_of_ip, + transform_action_acl_enum, + transform_change_obs_readable, +) +from primaite.common.enums import HardCodedAgentView + + +class HardCodedACLAgent(HardCodedAgentSessionABC): + + def _calculate_action(self, obs): + if 
self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: + # Basic view action using only the current observation + return self._calculate_action_basic_view(obs) + else: + # full view action using observation space, action + # history and reward feedback + return self._calculate_action_full_view(obs) + + def get_blocked_green_iers(self, green_iers, acl, nodes): + blocked_green_iers = {} + + for green_ier_id, green_ier in green_iers.items(): + source_node_id = green_ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = green_ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = green_ier.get_protocol() # e.g. 'TCP' + port = green_ier.get_port() + + # Can be blocked by an ACL or by default (no allow rule exists) + if acl.is_blocked(source_node_address, dest_node_address, protocol, + port): + blocked_green_iers[green_ier_id] = green_ier + + return blocked_green_iers + + def get_matching_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get matching ACL rules for an IER. + """ + + source_node_id = ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = ier.get_protocol() # e.g. 'TCP' + port = ier.get_port() + + matching_rules = acl.get_relevant_rules(source_node_address, + dest_node_address, protocol, + port) + return matching_rules + + def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get blocking ACL rules for an IER. 
+ Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked) + """ + + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + blocked_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + blocked_rules[rule_key] = rule_value + + return blocked_rules + + def get_allow_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get all allowing ACL rules for an IER. + """ + + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol, + port, acl, + nodes, services_list): + if source_node_id != "ANY": + source_node_address = nodes[str(source_node_id)].ip_address + else: + source_node_address = source_node_id + + if dest_node_id != "ANY": + dest_node_address = nodes[str(dest_node_id)].ip_address + else: + dest_node_address = dest_node_id + + if protocol != "ANY": + protocol = services_list[ + protocol - 1] # -1 as dont have to account for ANY in list of services + + matching_rules = acl.get_relevant_rules(source_node_address, + dest_node_address, protocol, + port) + return matching_rules + + def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol, + port, acl, + nodes, services_list): + matching_rules = self.get_matching_acl_rules(source_node_id, + dest_node_id, + protocol, port, acl, + nodes, + services_list) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port, + acl, + nodes, services_list): + matching_rules = self.get_matching_acl_rules(source_node_id, + 
dest_node_id, + protocol, port, acl, + nodes, + services_list) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def _calculate_action_full_view(self, obs): + """ + Given an observation and the environment calculate a good acl-based action for the blue agent to take + + Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: + + - Which ACL rules already exist, - otherwise: + - The agent would perminently get stuck in a loop of performing the same action over and over. + (best action is to block something, but its already blocked but doesn't know this) + - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule, + if it doesnt know what rules exist) + - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example + in the default config one of the green IERs is blocked by default, but it has no way of knowing this + based on the observation space. Additionally, potentially in the future, once a node state + has been fixed (no longer compromised), it needs a way to know it should reallow traffic. + A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this. + + There doesn't seem like there's much that can be done if an Operating or OS State is compromised + + If a service node becomes compromised there's a decision to make - do we block that service? + Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED + Cons: Will block a green IER, decreasing the reward + We decide to block the service. + + Potentially a better solution (for the reward) would be to block the incomming traffic from compromised + nodes once a service becomes overwhelmed. 
However currently the ACL action space has no way of reversing + an overwhelmed state, so we don't do this. + + """ + #obs = convert_to_old_obs(obs) + r_obs = transform_change_obs_readable(obs) + _, _, _, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + # 1. Check if node is compromised. If so we want to block its outwards services + # a. If it is comprimised check if there's an allow rule we should delete. + # cons: might delete a multi-rule from any source node (ANY -> x) + # b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate + # c. OPTIONAL (no allow rule = blocked): Add a DENY rule + found_action = False + for service_num, service_states in enumerate(s): + for x, service_state in enumerate(service_states): + if service_state == "COMPROMISED": + + action_source_id = x + 1 # +1 as 0 is any + action_destination_id = "ANY" + action_protocol = service_num + 1 # +1 as 0 is any + action_port = "ANY" + + allow_rules = self.get_allow_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + deny_rules = self.get_deny_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + if len(allow_rules) > 0: + # Check if there's an allow rule we should delete + rule = list(allow_rules.values())[0] + action_decision = "DELETE" + action_permission = "ALLOW" + action_source_ip = rule.get_source_ip() + action_source_id = int( + get_node_of_ip(action_source_ip, self._env.nodes)) + action_destination_ip = rule.get_dest_ip() + action_destination_id = int( + get_node_of_ip(action_destination_ip, + self._env.nodes)) + action_protocol_name = rule.get_protocol() + action_protocol = ( + self._env.services_list.index( + action_protocol_name) + 1 + ) # convert name e.g. 
'TCP' to index + action_port_name = rule.get_port() + action_port = self._env.ports_list.index( + action_port_name) + 1 # convert port name e.g. '80' to index + + found_action = True + break + elif len(deny_rules) > 0: + # TODO OPTIONAL + # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need + # to create another + # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule + continue + else: + # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked + action_decision = "CREATE" + action_permission = "DENY" + break + if found_action: + break + + # 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and + # add an Allow rule if the green IER is being blocked. + # a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule): + # If there's a DENY rule delete it if: + # - There isn't already a deny rule + # - It doesnt allows a comprimised node to become operational. + # b. 
Add an ALLOW rule if: + # - There isn't already an allow rule + # - It doesnt allows a comprimised node to become operational + + if not found_action: + # Which Green IERS are blocked + blocked_green_iers = self.get_blocked_green_iers( + self._env.green_iers, self._env.acl, + self._env.nodes) + for ier_key, ier in blocked_green_iers.items(): + + # Which ALLOW rules are allowing this IER (none) + allowing_rules = self.get_allow_acl_rules_for_ier(ier, + self._env.acl, + self._env.nodes) + + # If there are no blocking rules, it may be being blocked by default + # If there is already an allow rule + node_id_to_check = int(ier.get_source_node_id()) + service_name_to_check = ier.get_protocol() + service_id_to_check = self._env.services_list.index( + service_name_to_check) + + # Service state of the the source node in the ier + service_state = s[service_id_to_check][node_id_to_check - 1] + + if len(allowing_rules) == 0 and service_state != "COMPROMISED": + action_decision = "CREATE" + action_permission = "ALLOW" + action_source_id = int(ier.get_source_node_id()) + action_destination_id = int(ier.get_dest_node_id()) + action_protocol_name = ier.get_protocol() + action_protocol = self._env.services_list.index( + action_protocol_name) + 1 # convert name e.g. 'TCP' to index + action_port_name = ier.get_port() + action_port = self._env.ports_list.index( + action_port_name) + 1 # convert port name e.g. 
'80' to index + + found_action = True + break + + if found_action: + action = [ + action_decision, + action_permission, + action_source_id, + action_destination_id, + action_protocol, + action_port, + ] + action = transform_action_acl_enum(action) + action = get_new_action(action, self._env.action_dict) + else: + # If no good/useful action has been found, just perform a nothing action + action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + action = transform_action_acl_enum(action) + action = get_new_action(action, self._env.action_dict) + return action + + def _calculate_action_basic_view(self, obs): + """Given an observation calculate a good acl-based action for the blue agent to take + + Uses ONLY information from the current observation with NO knowledge of previous actions taken and + NO reward feedback. + + We rely on randomness to select the precise action, as we want to block all traffic originating from + a compromised node, without being able to tell: + 1. Which ACL rules already exist + 1. Which actions the agent has already tried. + + There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed. + + Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted. 
+ """ + action_dict = self._env.action_dict + r_obs = transform_change_obs_readable(obs) + _, o, _, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + number_of_nodes = len( + [i for i in o if i != "NONE"]) # number of nodes (not links) + for service_num, service_states in enumerate(s): + comprimised_states = [n for n, i in enumerate(service_states) if + i == "COMPROMISED"] + if len(comprimised_states) == 0: + # No states are COMPROMISED, try the next service + continue + + compromised_node = np.random.choice( + comprimised_states) + 1 # +1 as 0 would be any + action_decision = "DELETE" + action_permission = "ALLOW" + action_source_ip = compromised_node + # Randomly select a destination ID to block + action_destination_ip = np.random.choice( + list(range(1, number_of_nodes + 1)) + ["ANY"]) + action_destination_ip = int( + action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip + action_protocol = service_num + 1 # +1 as 0 is any + # Randomly select a port to block + # Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port + action_port = np.random.choice(list(range(1, len(s) + 1))) + + action = [ + action_decision, + action_permission, + action_source_ip, + action_destination_ip, + action_protocol, + action_port, + ] + action = transform_action_acl_enum(action) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # If no good/useful action has been found, just perform a nothing action + nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + nothing_action = transform_action_acl_enum(nothing_action) + nothing_action = get_new_action(nothing_action, action_dict) + return nothing_action diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py new file mode 100644 index 00000000..6db43da6 --- /dev/null +++ b/src/primaite/agents/hardcoded_node.py @@ -0,0 +1,97 @@ +from 
primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + transform_change_obs_readable, +) +from primaite.agents.utils import ( + transform_action_node_enum, +) + + +class HardCodedNodeAgent(HardCodedAgentSessionABC): + def _calculate_action(self, obs): + """Given an observation calculate a good node-based action for the blue agent to take""" + action_dict = self._env.action_dict + r_obs = transform_change_obs_readable(obs) + _, o, os, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + # Check in order of most important states (order doesn't currently matter, but it probably should) + # First see if any OS states are compromised + for x, os_state in enumerate(os): + if os_state == "COMPROMISED": + action_node_id = x + 1 + action_node_property = "OS" + property_action = "PATCHING" + action_service_index = 0 # does nothing isn't relevant for os + action = [action_node_id, action_node_property, + property_action, action_service_index] + action = transform_action_node_enum(action) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # Next, see if any Services are compromised + # We fix the compromised state before overwhelemd state, + # If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored + for service_num, service in enumerate(s): + for x, service_state in enumerate(service): + if service_state == "COMPROMISED": + action_node_id = x + 1 + action_node_property = "SERVICE" + property_action = "PATCHING" + action_service_index = service_num + + action = [action_node_id, action_node_property, + property_action, action_service_index] + action = transform_action_node_enum(action) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # Next, See if any services are overwhelmed + # perhaps this should be fixed automatically when the 
compromised PCs issues are also resolved + # Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs + + for service_num, service in enumerate(s): + for x, service_state in enumerate(service): + if service_state == "OVERWHELMED": + action_node_id = x + 1 + action_node_property = "SERVICE" + property_action = "PATCHING" + action_service_index = service_num + + action = [action_node_id, action_node_property, + property_action, action_service_index] + action = transform_action_node_enum(action) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # Finally, turn on any off nodes + for x, operating_state in enumerate(o): + if os_state == "OFF": + action_node_id = x + 1 + action_node_property = "OPERATING" + property_action = "ON" # Why reset it when we can just turn it on + action_service_index = 0 # does nothing isn't relevant for operating state + action = [action_node_id, action_node_property, + property_action, action_service_index] + action = transform_action_node_enum(action, action_dict) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # If no good actions, just go with an action that wont do any harm + action_node_id = 1 + action_node_property = "NONE" + property_action = "ON" + action_service_index = 0 + action = [action_node_id, action_node_property, property_action, + action_service_index] + action = transform_action_node_enum(action) + action = get_new_action(action, action_dict) + + return action diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 67ba6213..7d0cde60 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,21 +1,19 @@ import json from datetime import datetime from pathlib import Path -from pathlib import Path from typing import Optional from ray.rllib.algorithms import Algorithm -from ray.rllib.algorithms.ppo 
import PPOConfig from ray.rllib.algorithms.a2c import A2CConfig +from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, RedAgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite - _LOGGER = getLogger(__name__) def _env_creator(env_config): @@ -51,13 +49,13 @@ class RLlibAgent(AgentSessionABC): f"got {self._training_config.agent_framework}") _LOGGER.error(msg) raise ValueError(msg) - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_config_class = PPOConfig - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = A2CConfig else: - msg = ("Expected PPO or A2C red_agent_identifier, " - f"got {self._training_config.red_agent_identifier.value}") + msg = ("Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}") _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -67,8 +65,8 @@ class RLlibAgent(AgentSessionABC): _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " - f"red_agent_identifier=" - f"{self._training_config.red_agent_identifier}, " + f"agent_identifier=" + f"{self._training_config.agent_identifier}, " f"deep_learning_framework=" f"{self._training_config.deep_learning_framework}" ) @@ -117,7 +115,7 @@ class RLlibAgent(AgentSessionABC): train_batch_size=self._training_config.num_steps ) self._agent_config.framework( - framework=self._training_config.deep_learning_framework + framework="torch" ) self._agent_config.rollouts( diff --git 
a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 3cd2e50a..3748b57d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -2,12 +2,12 @@ from typing import Optional import numpy as np from stable_baselines3 import PPO, A2C +from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import RedAgentIdentifier, AgentFramework +from primaite.common.enums import AgentIdentifier, AgentFramework from primaite.environment.primaite_env import Primaite -from stable_baselines3.ppo import MlpPolicy as PPOMlp _LOGGER = getLogger(__name__) @@ -24,13 +24,13 @@ class SB3Agent(AgentSessionABC): f"got {self._training_config.agent_framework}") _LOGGER.error(msg) raise ValueError(msg) - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_class = PPO - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ("Expected PPO or A2C red_agent_identifier, " - f"got {self._training_config.red_agent_identifier.value}") + msg = ("Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}") _LOGGER.error(msg) raise ValueError(msg) @@ -40,8 +40,8 @@ class SB3Agent(AgentSessionABC): _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " - f"red_agent_identifier=" - f"{self._training_config.red_agent_identifier}" + f"agent_identifier=" + f"{self._training_config.agent_identifier}" ) def _setup(self): @@ -56,7 +56,7 @@ class SB3Agent(AgentSessionABC): self._agent = self._agent_class( PPOMlp, self._env, - verbose=self._training_config.output_verbose_level, + verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, 
tensorboard_log=self._tensorboard_log_path ) @@ -118,6 +118,7 @@ class SB3Agent(AgentSessionABC): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + @classmethod def load(self): raise NotImplementedError diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py new file mode 100644 index 00000000..cf333b1e --- /dev/null +++ b/src/primaite/agents/simple.py @@ -0,0 +1,60 @@ +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + transform_action_acl_enum, + transform_action_node_enum, +) + + +class RandomAgent(HardCodedAgentSessionABC): + """ + A Random Agent. + + Get a completely random action from the action space. + """ + + def _calculate_action(self, obs): + return self._env.action_space.sample() + + +class DummyAgent(HardCodedAgentSessionABC): + """ + A Dummy Agent. + + All action spaces setup so dummy action is always 0 regardless of action + type used. + """ + + def _calculate_action(self, obs): + return 0 + + +class DoNothingACLAgent(HardCodedAgentSessionABC): + """ + A do nothing ACL agent. + + A valid ACL action that has no effect; does nothing. + """ + + def _calculate_action(self, obs): + nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + nothing_action = transform_action_acl_enum(nothing_action) + nothing_action = get_new_action(nothing_action, self._env.action_dict) + + return nothing_action + + +class DoNothingNodeAgent(HardCodedAgentSessionABC): + """ + A do nothing Node agent. + + A valid Node action that has no effect; does nothing. 
+ """ + + def _calculate_action(self, obs): + nothing_action = [1, "NONE", "ON", 0] + nothing_action = transform_action_node_enum(nothing_action) + nothing_action = get_new_action(nothing_action, self._env.action_dict) + # nothing_action should currently always be 0 + + return nothing_action diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index bb967906..acc71590 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,4 +1,13 @@ -from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction +import numpy as np + +from primaite.common.enums import ( + HardwareState, + LinkStatus, + NodeHardwareAction, + NodeSoftwareAction, + SoftwareState, +) +from primaite.common.enums import NodePOLType def transform_action_node_readable(action): @@ -125,3 +134,393 @@ def is_valid_acl_action_extra(action): return False return True + + + +def transform_change_obs_readable(obs): + """Transform list of transactions to readable list of each observation property + + example: + + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] + """ + ids = [i for i in obs[:, 0]] + operating_states = [HardwareState(i).name for i in obs[:, 1]] + os_states = [SoftwareState(i).name for i in obs[:, 2]] + new_obs = [ids, operating_states, os_states] + + for service in range(3, obs.shape[1]): + # Links bit/s don't have a service state + service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] + new_obs.append(service_states) + + return new_obs + + +def transform_obs_readable(obs): + """ + example: + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] + """ + + changed_obs = transform_change_obs_readable(obs) + new_obs = list(zip(*changed_obs)) + # Convert list of tuples to list of lists + new_obs = [list(i) for i in new_obs] + + return new_obs + + +def convert_to_new_obs(obs, num_nodes=10): + """Convert 
original gym Box observation space to new multiDiscrete observation space""" + # Remove ID columns, remove links and flatten to MultiDiscrete observation space + new_obs = obs[:num_nodes, 1:].flatten() + return new_obs + + +def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): + """ + Convert to old observation, links filled with 0's as no information is included in new observation space + + example: + obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) + + new_obs = array([[ 1, 1, 1, 1], + [ 2, 1, 1, 1], + [ 3, 1, 1, 1], + ... + [20, 0, 0, 0]]) + """ + + # Convert back to more readable, original format + reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) + + # Add empty links back and add node ID back + s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64) + s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back + s[:num_nodes, 1:] = reshaped_nodes # put values back in + new_obs = s + + # Add links back in + links = obs[-num_links:] + # Links will be added to the last protocol/service slot but they are not specific to that service + new_obs[num_nodes:, -1] = links + + return new_obs + + +def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): + """Return string describing change between two observations + + example: + obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) + obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) + output = 'ID 1: SERVICE 2 set to GOOD' + + """ + obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) + obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) + list_of_changes = [] + for n, row in enumerate(obs1 - obs2): + if row.any() != 0: + relevant_changes = np.where(row != 0, obs2[n], -1) + relevant_changes[0] = obs2[n, 0] # ID is always relevant + is_link = relevant_changes[0] > num_nodes + desc = _describe_obs_change_helper(relevant_changes, is_link) + list_of_changes.append(desc) + + 
change_string = "\n ".join(list_of_changes) + if len(list_of_changes) > 0: + change_string = "\n " + change_string + return change_string + + +def _describe_obs_change_helper(obs_change, is_link): + """ " + Helper funcion to describe what has changed + + example: + [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" + + Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' + + """ + # Indexes where a change has occured, not including 0th index + index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] + # Node pol types, Indexes >= 3 are service nodes + NodePOLTypes = [ + NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed + ] + # Account for hardware states, software sattes and links + states = [ + LinkStatus(obs_change[i]).name + if is_link + else HardwareState(obs_change[i]).name + if i == 1 + else SoftwareState(obs_change[i]).name + for i in index_changed + ] + + if not is_link: + desc = f"ID {obs_change[0]}:" + for NodePOLType, state in list(zip(NodePOLTypes, states)): + desc = desc + " " + NodePOLType + " changed to " + state + "." + else: + desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." 
+ + return desc + + +def transform_action_node_enum(action): + """ + Convert a node action from readable string format, to enumerated format + + example: + [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] + """ + + action_node_id = action[0] + action_node_property = NodePOLType[action[1]].value + + if action[1] == "OPERATING": + property_action = NodeHardwareAction[action[2]].value + elif action[1] == "OS" or action[1] == "SERVICE": + property_action = NodeSoftwareAction[action[2]].value + else: + property_action = 0 + + action_service_index = action[3] + + new_action = [action_node_id, action_node_property, property_action, action_service_index] + + return new_action + + +def transform_action_node_readable(action): + """ + Convert a node action from enumerated format to readable format + + example: + [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] + """ + + action_node_property = NodePOLType(action[1]).name + + if action_node_property == "OPERATING": + property_action = NodeHardwareAction(action[2]).name + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: + property_action = NodeSoftwareAction(action[2]).name + else: + property_action = "NONE" + + new_action = [action[0], action_node_property, property_action, action[3]] + return new_action + + +def node_action_description(action): + """ + Generate string describing a node-based action + """ + + if isinstance(action[1], (int, np.int64)): + # transform action to readable format + action = transform_action_node_readable(action) + + node_id = action[0] + node_property = action[1] + property_action = action[2] + service_id = action[3] + + if property_action == "NONE": + return "" + if node_property == "OPERATING" or node_property == "OS": + description = f"NODE {node_id}, {node_property}, SET TO {property_action}" + elif node_property == "SERVICE": + description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" + else: + return "" + + return description + + 
def transform_action_acl_readable(action):
    """
    Transform an ACL action from enumerated format to a more readable format.

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]

    :param action: enumerated ACL action
        [decision, permission, source IP index, dest IP index, protocol, port].
    :return: list with decision/permission as strings and 0 mapped to "ANY".
    """
    action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
    action_permissions = {0: "DENY", 1: "ALLOW"}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, ports and protocols, 0 means ANY; otherwise it's just an index.
    new_action = [action_decision, action_permission] + list(action[2:6])
    for n, val in enumerate(list(action[2:6])):
        if val == 0:
            new_action[n + 2] = "ANY"

    return new_action


def transform_action_acl_enum(action):
    """
    Convert an ACL action from readable string format to enumerated format.

    Inverse of transform_action_acl_readable.

    :param action: readable ACL action, e.g. ['NONE', 'ALLOW', 2, 5, 'ANY', 1].
    :return: numpy array of the enumerated action.
    """
    action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    action_permissions = {"DENY": 0, "ALLOW": 1}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, ports and protocols, ANY has value 0; otherwise it's an index.
    new_action = [action_decision, action_permission] + list(action[2:6])
    for n, val in enumerate(list(action[2:6])):
        if val == "ANY":
            new_action[n + 2] = 0

    new_action = np.array(new_action)
    return new_action


def acl_action_description(action):
    """
    Generate a string describing an ACL-based action.

    Accepts either the enumerated or the readable action format.

    :param action: ACL action, readable or enumerated format.
    :return: description string.
    """
    if isinstance(action[0], (int, np.int64)):
        # Transform the action to readable format first.
        action = transform_action_acl_readable(action)
    if action[0] == "NONE":
        description = "NO ACL RULE APPLIED"
    else:
        description = (
            f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
            f" for protocol/service index {action[4]} on port index {action[5]}"
        )

    return description


def get_node_of_ip(ip, node_dict):
    """
    Get the node ID of an IP address.

    :param ip: the IP address to look up.
    :param node_dict: dictionary of nodes where key is ID and value is the
        node (can be obtained from env.nodes).
    :return: the matching node ID, or None when no node has that IP.
    """
    for node_key, node_value in node_dict.items():
        if node_value.ip_address == ip:
            return node_key
    return None


def is_valid_node_action(action):
    """
    Is the node action an actual valid action?

    Only uses information about the action itself to determine whether it can
    have an effect.

    Does NOT consider:
    - Node ID not valid to perform an operation - e.g. selected node has no
      service so cannot patch
    - Node already being in that state (turning an ON node ON)

    :param action: node action in enumerated format.
    :return: True when the action can have an effect, False otherwise.
    """
    action_r = transform_action_node_readable(action)

    node_property = action_r[1]
    node_action = action_r[2]

    if node_property == "NONE":
        return False
    if node_action == "NONE":
        return False
    if node_property == "OPERATING" and node_action == "PATCHING":
        # The operating state cannot PATCH.
        return False
    if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
        # Software states can only do nothing or patch.
        return False
    return True


def is_valid_acl_action(action):
    """
    Is the ACL action an actual valid action?

    Only uses information about the action itself to determine whether it can
    have an effect.

    Does NOT consider:
    - Trying to create identical rules
    - Trying to create a rule which is a subset of another rule (caused by "ANY")

    :param action: ACL action in enumerated format.
    :return: True when the action can have an effect, False otherwise.
    """
    action_r = transform_action_acl_readable(action)

    action_decision = action_r[0]
    action_permission = action_r[1]
    action_source_id = action_r[2]
    action_destination_id = action_r[3]

    if action_decision == "NONE":
        return False
    if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
        # ACL rule towards itself.
        return False
    if action_permission == "DENY":
        # DENY is unnecessary; we can create and delete ALLOW rules instead.
        # No ALLOW rule = blocked/DENY by default. ALLOW overrides existing DENY.
        return False

    return True


def is_valid_acl_action_extra(action):
    """
    Harsher version of is_valid_acl_action: additionally rejects ACL actions
    whose protocol or port is ANY.

    :param action: ACL action in enumerated format.
    :return: True when the action passes the stricter check, False otherwise.
    """
    if is_valid_acl_action(action) is False:
        return False

    action_r = transform_action_acl_readable(action)
    action_protocol = action_r[4]
    action_port = action_r[5]

    # Don't allow protocols or ports to be ANY.
    # In the future we might want to do the opposite, and only have the ANY
    # option for ports and services.
    if action_protocol == "ANY":
        return False
    if action_port == "ANY":
        return False

    return True


def get_new_action(old_action, action_dict):
    """
    Get a new-style action key (e.g. 32) from an old-style action
    (e.g. [1, 1, 1, 0]).

    old_action can be either the node or the ACL action type.

    :param old_action: the old-style action list/array.
    :param action_dict: mapping of action key -> old-style action.
    :return: the matching key, or 0 when the action is not in the dict
        (only valid actions are included, so a miss means an invalid action).
    """
    for key, val in action_dict.items():
        if list(val) == list(old_action):
            return key
    return 0


def get_action_description(action, action_dict):
    """
    Get a string describing/explaining what an action is doing in words.

    :param action: action key into action_dict.
    :param action_dict: mapping of action key -> action array.
    :return: description string.
    """
    action_array = action_dict[action]
    if len(action_array) == 4:
        # Node actions have length 4.
        action_description = node_action_description(action_array)
    elif len(action_array) == 6:
        # ACL actions have length 6.
        action_description = acl_action_description(action_array)
    else:
        # Should never happen.
        action_description = "Unrecognised action"

    return action_description


# NOTE(review): the original chunk continued past this point with the start of
# a git modification hunk for src/primaite/common/enums.py, cut off mid-header
# ("@@ -94,7"). It is preserved verbatim below, commented out, so no patch
# content is lost; it belongs to the surrounding diff, not to this module.
# diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py
# index 89bfd737..191cb782 100644
# --- a/src/primaite/common/enums.py
# +++ b/src/primaite/common/enums.py
# @@ -32,6 +32,7 @@ class Priority(Enum):
#  class HardwareState(Enum):
#      """Node hardware state enumeration."""
# +    NONE = 0
#      ON = 1
#      OFF = 2
#      RESETTING = 3
# @@ -42,6 +43,7 @@
#  class SoftwareState(Enum):
#      """Software or Service state enumeration."""
# +    NONE = 0
#      GOOD = 1
#      PATCHING = 2
#      COMPROMISED = 3
# @@ -94,7
+96,8 @@ class VerboseLevel(IntEnum): class AgentFramework(Enum): - NONE = 0 + """The agent algorithm framework/package.""" + CUSTOM = 0 "Custom Agent" SB3 = 1 "Stable Baselines3" @@ -103,7 +106,7 @@ class AgentFramework(Enum): class DeepLearningFramework(Enum): - """The deep learning framework enumeration.""" + """The deep learning framework.""" TF = "tf" "Tensorflow" TF2 = "tf2" @@ -112,15 +115,28 @@ class DeepLearningFramework(Enum): "PyTorch" -class RedAgentIdentifier(Enum): +class AgentIdentifier(Enum): + """The Red Agent algo/class.""" A2C = 1 "Advantage Actor Critic" PPO = 2 "Proximal Policy Optimization" HARDCODED = 3 - "Custom Agent" - RANDOM = 4 - "Custom Agent" + "The Hardcoded agents" + DO_NOTHING = 4 + "The DoNothing agents" + RANDOM = 5 + "The RandomAgent" + DUMMY = 6 + "The DummyAgent" + + +class HardCodedAgentView(Enum): + """The view the deterministic hard-coded agent has of the environment.""" + BASIC = 1 + "The current observation space only" + FULL = 2 + "Full environment view with actions taken and reward feedback" class ActionType(Enum): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d7b4db98..2cc29c55 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -1,32 +1,41 @@ -# Main Config File +# Training Config File -# Sets which agent algorithm framework will be used: +# Sets which agent algorithm framework will be used. # Options are: # "SB3" (Stable Baselines3) # "RLLIB" (Ray RLlib) -# "NONE" (Custom Agent) +# "CUSTOM" (Custom Agent) agent_framework: RLLIB -# Sets which deep learning framework will be used. Default is TF (Tensorflow). +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). 
# Options are: # "TF" (Tensorflow) # TF2 (Tensorflow 2.X) # TORCH (PyTorch) deep_learning_framework: TORCH -# Sets which Red Agent algo/class will be used: +# Sets which Agent class will be used. # Options are: -# "A2C" (Advantage Actor Critic) -# "PPO" (Proximal Policy Optimization) -# "HARDCODED" (Custom Agent) -# "RANDOM" (Random Action) -red_agent_identifier: PPO +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions -action_type: NODE +action_type: ACL # Number of episodes to run per session num_episodes: 10 diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4695f2f5..f8adae25 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -8,8 +8,8 @@ from typing import Any, Dict, Final, Union, Optional import yaml from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import DeepLearningFramework -from primaite.common.enums import ActionType, RedAgentIdentifier, \ +from primaite.common.enums import DeepLearningFramework, HardCodedAgentView +from primaite.common.enums import ActionType, AgentIdentifier, \ AgentFramework, SessionType, OutputVerboseLevel _LOGGER = getLogger(__name__) @@ -42,8 +42,11 @@ class TrainingConfig: 
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF "The DeepLearningFramework" - red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO - "The RedAgentIdentifier" + agent_identifier: AgentIdentifier = AgentIdentifier.PPO + "The AgentIdentifier" + + hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL + "The view the deterministic hard-coded agent has of the environment" action_type: ActionType = ActionType.ANY "The ActionType to use" @@ -176,10 +179,11 @@ class TrainingConfig: field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, - "red_agent_identifier": RedAgentIdentifier, + "agent_identifier": AgentIdentifier, "action_type": ActionType, "session_type": SessionType, - "output_verbose_level": OutputVerboseLevel + "output_verbose_level": OutputVerboseLevel, + "hard_coded_agent_view": HardCodedAgentView } for field, enum_class in field_enum_map.items(): @@ -197,12 +201,13 @@ class TrainingConfig: """ data = self.__dict__ if json_serializable: - data["agent_framework"] = self.agent_framework.value - data["deep_learning_framework"] = self.deep_learning_framework.value - data["red_agent_identifier"] = self.red_agent_identifier.value - data["action_type"] = self.action_type.value - data["output_verbose_level"] = self.output_verbose_level.value - data["session_type"] = self.session_type.value + data["agent_framework"] = self.agent_framework.name + data["deep_learning_framework"] = self.deep_learning_framework.name + data["agent_identifier"] = self.agent_identifier.name + data["action_type"] = self.action_type.name + data["output_verbose_level"] = self.output_verbose_level.name + data["session_type"] = self.session_type.name + data["hard_coded_agent_view"] = self.hard_coded_agent_view.name return data @@ -255,7 +260,7 @@ def load( def convert_legacy_training_config_dict( legacy_config_dict: Dict[str, Any], agent_framework: AgentFramework = AgentFramework.SB3, - 
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, + agent_identifier: AgentIdentifier = AgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, num_steps: int = 256, output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO @@ -266,8 +271,8 @@ def convert_legacy_training_config_dict( :param legacy_config_dict: A legacy training config dict. :param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values. - :param red_agent_identifier: The red agent identifier to use as legacy - training configs don't have red_agent_identifier values. + :param agent_identifier: The red agent identifier to use as legacy + training configs don't have agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. :param num_steps: The number of steps to set as legacy training configs @@ -278,7 +283,7 @@ def convert_legacy_training_config_dict( """ config_dict = { "agent_framework": agent_framework.name, - "red_agent_identifier": red_agent_identifier.name, + "agent_identifier": agent_identifier.name, "action_type": action_type.name, "num_steps": num_steps, "output_verbose_level": output_verbose_level diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0876f070..502069ec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -97,7 +97,7 @@ class Primaite(Env): self.transaction_list = transaction_list # The agent in use - self.agent_identifier = self.training_config.red_agent_identifier + self.agent_identifier = self.training_config.agent_identifier # Create a dictionary to hold all the nodes self.nodes: Dict[str, NodeUnion] = {} diff --git a/src/primaite/main.py b/src/primaite/main.py index 34134ba2..100248dd 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,137 +1,15 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. 
Shared in confidence. -""" -The main PrimAITE session runner module. - -TODO: This will eventually be refactored out into a proper Session class. -TODO: The passing about of session_path and timestamp_str is temporary and - will be cleaned up once we move to a proper Session class. -""" +"""The main PrimAITE session runner module.""" import argparse -import json -import time -from datetime import datetime from pathlib import Path -from typing import Final, Union -from uuid import uuid4 +from typing import Union -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import SESSIONS_DIR, getLogger -from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite +from primaite import getLogger from primaite.primaite_session import PrimaiteSession -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file _LOGGER = getLogger(__name__) -def run_generic(env: Primaite, config_values: TrainingConfig): - """ - Run against a generic agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. 
- """ - for episode in range(0, config_values.num_episodes): - env.reset() - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = env.action_space.sample() - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - - env.close() - - -def run_stable_baselines3_ppo( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 PPO agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. 
- """ - if config_values.load_agent: - try: - agent = PPO.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - - - -def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): - """ - Persist an agent. - - Only works for stable baselines3 agents at present. - - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if not isinstance(agent, OnPolicyAlgorithm): - msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." - _LOGGER.error(msg) - else: - filepath = session_path / f"agent_saved_{timestamp_str}" - agent.save(filepath) - _LOGGER.debug(f"Trained agent saved as: {filepath}") - - - - def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): """Run the PrimAITE Session. 
diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index a4148d12..70a18a4b 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -8,9 +8,13 @@ from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR from primaite.agents.agent import AgentSessionABC +from primaite.agents.hardcoded_acl import HardCodedACLAgent +from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ +from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \ + RandomAgent, DummyAgent +from primaite.common.enums import AgentFramework, AgentIdentifier, \ ActionType, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig @@ -68,31 +72,66 @@ class PrimaiteSession: self.learn() def setup(self): - if self._training_config.agent_framework == AgentFramework.NONE: - if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: - # Stochastic Random Agent - raise NotImplementedError - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.agent_framework == AgentFramework.CUSTOM: + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - raise NotImplementedError + self._agent_session = HardCodedNodeAgent( + self._training_config_path, + self._lay_down_config_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - raise NotImplementedError + self._agent_session = HardCodedACLAgent( + self._training_config_path, + self._lay_down_config_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic 
Hardcoded Agent with ANY Action Space raise NotImplementedError else: - # Invalid RedAgentIdentifier ActionType combo - pass + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + if self._training_config.action_type == ActionType.NODE: + self._agent_session = DoNothingNodeAgent( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + self._agent_session = DoNothingACLAgent( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: + self._agent_session = RandomAgent( + self._training_config_path, + self._lay_down_config_path + ) + elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: + self._agent_session = DummyAgent( + self._training_config_path, + self._lay_down_config_path + ) else: - # Invalid AgentFramework RedAgentIdentifier combo - pass + # Invalid AgentFramework AgentIdentifier combo + raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: # Stable Baselines3 Agent @@ -110,7 +149,7 @@ class PrimaiteSession: else: # Invalid AgentFramework - pass + raise ValueError def learn( self, diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index 44897bfa..9fdf9a05 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -13,7 +13,7 @@ agent_framework: RLLIB # "A2C" (Advantage Actor Critic) # "HARDCODED" (Custom Agent) # "RANDOM" (Random Action) -red_agent_identifier: PPO +agent_identifier: PPO # Sets How the 
Action Space is defined: # "NODE"