@@ -165,12 +165,10 @@ class Primaite(Env):
|
||||
# Number of ports - gets a value when config is loaded
|
||||
self.num_ports = 0
|
||||
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
|
||||
# Observation type, by default box.
|
||||
self.observation_type = ObservationType.BOX
|
||||
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -206,7 +204,7 @@ class Primaite(Env):
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
if self.action_type == ActionType.NODE:
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.info("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
@@ -215,7 +213,7 @@ class Primaite(Env):
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ACL:
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
@@ -226,12 +224,12 @@ class Primaite(Env):
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ANY:
|
||||
elif self.training_config.action_type == ActionType.ANY:
|
||||
_LOGGER.info("Action space type ANY selected - Node + ACL")
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info("Invalid action type selected")
|
||||
_LOGGER.info(f"Invalid action type selected")
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
@@ -426,9 +424,12 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
if self.action_type == ActionType.NODE:
|
||||
print("")
|
||||
print(_action)
|
||||
print(self.action_dict)
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.action_type == ActionType.ACL:
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
@@ -901,34 +902,34 @@ class Primaite(Env):
|
||||
def load_lay_down_config(self):
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.lay_down_config:
|
||||
if item["itemType"] == "NODE":
|
||||
if item["item_type"] == "NODE":
|
||||
# Create a node
|
||||
self.create_node(item)
|
||||
elif item["itemType"] == "LINK":
|
||||
elif item["item_type"] == "LINK":
|
||||
# Create a link
|
||||
self.create_link(item)
|
||||
elif item["itemType"] == "GREEN_IER":
|
||||
elif item["item_type"] == "GREEN_IER":
|
||||
# Create a Green IER
|
||||
self.create_green_ier(item)
|
||||
elif item["itemType"] == "GREEN_POL":
|
||||
elif item["item_type"] == "GREEN_POL":
|
||||
# Create a Green PoL
|
||||
self.create_green_pol(item)
|
||||
elif item["itemType"] == "RED_IER":
|
||||
elif item["item_type"] == "RED_IER":
|
||||
# Create a Red IER
|
||||
self.create_red_ier(item)
|
||||
elif item["itemType"] == "RED_POL":
|
||||
elif item["item_type"] == "RED_POL":
|
||||
# Create a Red PoL
|
||||
self.create_red_pol(item)
|
||||
elif item["itemType"] == "ACL_RULE":
|
||||
elif item["item_type"] == "ACL_RULE":
|
||||
# Create an ACL rule
|
||||
self.create_acl_rule(item)
|
||||
elif item["itemType"] == "SERVICES":
|
||||
elif item["item_type"] == "SERVICES":
|
||||
# Create the list of services
|
||||
self.create_services_list(item)
|
||||
elif item["itemType"] == "PORTS":
|
||||
elif item["item_type"] == "PORTS":
|
||||
# Create the list of ports
|
||||
self.create_ports_list(item)
|
||||
elif item["itemType"] == "OBSERVATIONS":
|
||||
elif item["item_type"] == "OBSERVATIONS":
|
||||
# Get the observation information
|
||||
self.get_observation_info(item)
|
||||
else:
|
||||
@@ -1071,14 +1072,14 @@ class Primaite(Env):
|
||||
item: A config data item
|
||||
"""
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
ier_start_step = item["start_step"]
|
||||
ier_end_step = item["end_step"]
|
||||
ier_load = item["load"]
|
||||
ier_protocol = item["protocol"]
|
||||
ier_port = item["port"]
|
||||
ier_source = item["source"]
|
||||
ier_destination = item["destination"]
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
ier_mission_criticality = item["mission_criticality"]
|
||||
|
||||
# Create IER and add to green IER dictionary
|
||||
self.green_iers[ier_id] = IER(
|
||||
@@ -1101,14 +1102,14 @@ class Primaite(Env):
|
||||
item: A config data item
|
||||
"""
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
ier_start_step = item["start_step"]
|
||||
ier_end_step = item["end_step"]
|
||||
ier_load = item["load"]
|
||||
ier_protocol = item["protocol"]
|
||||
ier_port = item["port"]
|
||||
ier_source = item["source"]
|
||||
ier_destination = item["destination"]
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
ier_mission_criticality = item["mission_criticality"]
|
||||
|
||||
# Create IER and add to red IER dictionary
|
||||
self.red_iers[ier_id] = IER(
|
||||
@@ -1131,8 +1132,8 @@ class Primaite(Env):
|
||||
item: A config data item
|
||||
"""
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
pol_start_step = item["start_step"]
|
||||
pol_end_step = item["end_step"]
|
||||
pol_node = item["nodeId"]
|
||||
pol_type = NodePOLType[item["type"]]
|
||||
|
||||
@@ -1165,8 +1166,8 @@ class Primaite(Env):
|
||||
item: A config data item
|
||||
"""
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
pol_start_step = item["start_step"]
|
||||
pol_end_step = item["end_step"]
|
||||
pol_target_node_id = item["targetNodeId"]
|
||||
pol_initiator = NodePOLInitiator[item["initiator"]]
|
||||
pol_type = NodePOLType[item["type"]]
|
||||
@@ -1226,7 +1227,7 @@ class Primaite(Env):
|
||||
Args:
|
||||
item: A config data item representing the services
|
||||
"""
|
||||
service_list = services["serviceList"]
|
||||
service_list = services["service_list"]
|
||||
|
||||
for service in service_list:
|
||||
service_name = service["name"]
|
||||
@@ -1242,7 +1243,7 @@ class Primaite(Env):
|
||||
Args:
|
||||
item: A config data item representing the ports
|
||||
"""
|
||||
ports_list = ports["portsList"]
|
||||
ports_list = ports["ports_list"]
|
||||
|
||||
for port in ports_list:
|
||||
port_value = port["port"]
|
||||
@@ -1267,10 +1268,10 @@ class Primaite(Env):
|
||||
configuration.
|
||||
"""
|
||||
for item in self.lay_down_config:
|
||||
if item["itemType"] == "NODE":
|
||||
if item["item_type"] == "NODE":
|
||||
# Reset a node's state (normal and reference)
|
||||
self.reset_node(item)
|
||||
elif item["itemType"] == "ACL_RULE":
|
||||
elif item["item_type"] == "ACL_RULE":
|
||||
# Create an ACL rule (these are cleared on reset, so just need to recreate them)
|
||||
self.create_acl_rule(item)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user