#915 - Synced with dev to bring in changes from #898

This commit is contained in:
Chris McCarthy
2023-06-09 13:11:14 +01:00
parent 9b4ed1199b
commit af4e71db9b
14 changed files with 510 additions and 504 deletions

View File

@@ -165,12 +165,10 @@ class Primaite(Env):
# Number of ports - gets a value when config is loaded
self.num_ports = 0
# The action type
self.action_type = 0
# Observation type, by default box.
self.observation_type = ObservationType.BOX
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -206,7 +204,7 @@ class Primaite(Env):
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
if self.action_type == ActionType.NODE:
if self.training_config.action_type == ActionType.NODE:
_LOGGER.info("Action space type NODE selected")
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
@@ -215,7 +213,7 @@ class Primaite(Env):
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ACL:
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -226,12 +224,12 @@ class Primaite(Env):
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ANY:
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info("Invalid action type selected")
_LOGGER.info(f"Invalid action type selected")
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -426,9 +424,12 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
if self.action_type == ActionType.NODE:
print("")
print(_action)
print(self.action_dict)
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.action_type == ActionType.ACL:
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
@@ -901,34 +902,34 @@ class Primaite(Env):
def load_lay_down_config(self):
"""Loads config data in order to build the environment configuration."""
for item in self.lay_down_config:
if item["itemType"] == "NODE":
if item["item_type"] == "NODE":
# Create a node
self.create_node(item)
elif item["itemType"] == "LINK":
elif item["item_type"] == "LINK":
# Create a link
self.create_link(item)
elif item["itemType"] == "GREEN_IER":
elif item["item_type"] == "GREEN_IER":
# Create a Green IER
self.create_green_ier(item)
elif item["itemType"] == "GREEN_POL":
elif item["item_type"] == "GREEN_POL":
# Create a Green PoL
self.create_green_pol(item)
elif item["itemType"] == "RED_IER":
elif item["item_type"] == "RED_IER":
# Create a Red IER
self.create_red_ier(item)
elif item["itemType"] == "RED_POL":
elif item["item_type"] == "RED_POL":
# Create a Red PoL
self.create_red_pol(item)
elif item["itemType"] == "ACL_RULE":
elif item["item_type"] == "ACL_RULE":
# Create an ACL rule
self.create_acl_rule(item)
elif item["itemType"] == "SERVICES":
elif item["item_type"] == "SERVICES":
# Create the list of services
self.create_services_list(item)
elif item["itemType"] == "PORTS":
elif item["item_type"] == "PORTS":
# Create the list of ports
self.create_ports_list(item)
elif item["itemType"] == "OBSERVATIONS":
elif item["item_type"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
else:
@@ -1071,14 +1072,14 @@ class Primaite(Env):
item: A config data item
"""
ier_id = item["id"]
ier_start_step = item["startStep"]
ier_end_step = item["endStep"]
ier_start_step = item["start_step"]
ier_end_step = item["end_step"]
ier_load = item["load"]
ier_protocol = item["protocol"]
ier_port = item["port"]
ier_source = item["source"]
ier_destination = item["destination"]
ier_mission_criticality = item["missionCriticality"]
ier_mission_criticality = item["mission_criticality"]
# Create IER and add to green IER dictionary
self.green_iers[ier_id] = IER(
@@ -1101,14 +1102,14 @@ class Primaite(Env):
item: A config data item
"""
ier_id = item["id"]
ier_start_step = item["startStep"]
ier_end_step = item["endStep"]
ier_start_step = item["start_step"]
ier_end_step = item["end_step"]
ier_load = item["load"]
ier_protocol = item["protocol"]
ier_port = item["port"]
ier_source = item["source"]
ier_destination = item["destination"]
ier_mission_criticality = item["missionCriticality"]
ier_mission_criticality = item["mission_criticality"]
# Create IER and add to red IER dictionary
self.red_iers[ier_id] = IER(
@@ -1131,8 +1132,8 @@ class Primaite(Env):
item: A config data item
"""
pol_id = item["id"]
pol_start_step = item["startStep"]
pol_end_step = item["endStep"]
pol_start_step = item["start_step"]
pol_end_step = item["end_step"]
pol_node = item["nodeId"]
pol_type = NodePOLType[item["type"]]
@@ -1165,8 +1166,8 @@ class Primaite(Env):
item: A config data item
"""
pol_id = item["id"]
pol_start_step = item["startStep"]
pol_end_step = item["endStep"]
pol_start_step = item["start_step"]
pol_end_step = item["end_step"]
pol_target_node_id = item["targetNodeId"]
pol_initiator = NodePOLInitiator[item["initiator"]]
pol_type = NodePOLType[item["type"]]
@@ -1226,7 +1227,7 @@ class Primaite(Env):
Args:
item: A config data item representing the services
"""
service_list = services["serviceList"]
service_list = services["service_list"]
for service in service_list:
service_name = service["name"]
@@ -1242,7 +1243,7 @@ class Primaite(Env):
Args:
item: A config data item representing the ports
"""
ports_list = ports["portsList"]
ports_list = ports["ports_list"]
for port in ports_list:
port_value = port["port"]
@@ -1267,10 +1268,10 @@ class Primaite(Env):
configuration.
"""
for item in self.lay_down_config:
if item["itemType"] == "NODE":
if item["item_type"] == "NODE":
# Reset a node's state (normal and reference)
self.reset_node(item)
elif item["itemType"] == "ACL_RULE":
elif item["item_type"] == "ACL_RULE":
# Create an ACL rule (these are cleared on reset, so just need to recreate them)
self.create_acl_rule(item)
else: