893 - set the action_space to NOTHING so test_reward.py passes and removed unnecessary test print statements

This commit is contained in:
SunilSamra
2023-06-06 11:10:38 +01:00
parent 1a7d629d5a
commit dc7be7d8e6
2 changed files with 3 additions and 22 deletions

View File

@@ -236,7 +236,6 @@ class Primaite(Env):
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
print(self.action_space, "NODE action space")
elif self.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
@@ -248,12 +247,10 @@ class Primaite(Env):
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
print(self.action_space, "ACL action space")
elif self.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
print(self.action_space, "ANY action space")
else:
_LOGGER.info("Invalid action type selected")
# Set up a csv to store the results of the training
@@ -455,7 +452,6 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
# print("intepret action")
# At the moment, actions are only affecting nodes
if self.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
@@ -470,7 +466,6 @@ class Primaite(Env):
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
print("invalid action type found")
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
@@ -1091,7 +1086,6 @@ class Primaite(Env):
item: A config data item representing action info
"""
self.action_type = ActionType[action_info["type"]]
print("action type selected: ", self.action_type)
def get_steps_info(self, steps_info):
"""
@@ -1196,7 +1190,6 @@ class Primaite(Env):
"""
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
# print("node dict function call", self.num_nodes + 1)
action_key = 1
for node in range(1, self.num_nodes + 1):
# 4 node properties (NONE, OPERATING, OS, SERVICE)
@@ -1204,14 +1197,11 @@ class Primaite(Env):
# Node Actions either:
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
# Use MAX to ensure we get them all
# print(self.num_services, "num services")
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action, service_state]
# check to see if its a nothing aciton (has no effect)
# print("action node",action)
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
print("true")
actions[action_key] = action
action_key += 1
@@ -1223,7 +1213,6 @@ class Primaite(Env):
actions = {0: [0, 0, 0, 0, 0, 0]}
action_key = 1
# print("node count",self.num_nodes + 1)
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):
# 2 possible action permissions 0 = DENY, 1 = CREATE
@@ -1241,13 +1230,10 @@ class Primaite(Env):
protocol,
port,
]
# print("action acl", action)
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
print("true")
actions[action_key] = action
action_key += 1
# print("false")
return actions
@@ -1261,8 +1247,6 @@ class Primaite(Env):
node_action_dict = self.create_node_action_dict()
acl_action_dict = self.create_acl_action_dict()
print(len(node_action_dict), len(acl_action_dict))
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
@@ -1273,6 +1257,4 @@ class Primaite(Env):
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
print("combined_action_dict entry", combined_action_dict.items())
# print(len(list(combined_action_dict.values())))
return combined_action_dict

View File

@@ -184,8 +184,8 @@ def run_generic(env, config_values):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = env.action_space.sample()
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
@@ -222,7 +222,6 @@ def run_generic_set_actions(env, config_values):
# Sets Node 1 Hardware State to OFF
# Does not resolve any service
action = 16
print(action, "ran")
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)