Remove redundant cols from transactions
This commit is contained in:
@@ -168,7 +168,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
f"link_{link_id}_n/a",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
link_labels.append(f"node_{node_id}_service_{serv}_load")
|
||||
link_labels.append(f"link_{link_id}_service_{serv}_load")
|
||||
|
||||
structure.extend(link_labels)
|
||||
return structure
|
||||
|
||||
@@ -318,7 +318,7 @@ class Primaite(Env):
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(self.obs_handler._flat_observation)
|
||||
transaction.set_obs_space(self.obs_handler._flat_observation)
|
||||
|
||||
# Load the action space into the transaction
|
||||
transaction.set_action_space(copy.deepcopy(action))
|
||||
@@ -400,8 +400,6 @@ class Primaite(Env):
|
||||
|
||||
# 7. Update env_obs
|
||||
self.update_environent_obs()
|
||||
# Load the new observation space into the transaction
|
||||
transaction.set_obs_space_post(self.obs_handler._flat_observation)
|
||||
|
||||
# 8. Add the transaction to the list of transactions
|
||||
self.transaction_list.append(copy.deepcopy(transaction))
|
||||
|
||||
@@ -20,23 +20,14 @@ class Transaction(object):
|
||||
self.episode_number = _episode_number
|
||||
self.step_number = _step_number
|
||||
|
||||
def set_obs_space_pre(self, _obs_space_pre):
|
||||
def set_obs_space(self, _obs_space):
|
||||
"""
|
||||
Sets the observation space (pre).
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
self.obs_space_pre = _obs_space_pre
|
||||
|
||||
def set_obs_space_post(self, _obs_space_post):
|
||||
"""
|
||||
Sets the observation space (post).
|
||||
|
||||
Args:
|
||||
_obs_space_post: The observation space after any actions are taken
|
||||
"""
|
||||
self.obs_space_post = _obs_space_post
|
||||
self.obs_space = _obs_space
|
||||
|
||||
def set_reward(self, _reward):
|
||||
"""
|
||||
|
||||
@@ -58,12 +58,12 @@ def write_transaction_to_file(
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = [f"pre_{o}" for o in obs_space_description]
|
||||
obs_header_new = [f"post_{o}" for o in obs_space_description]
|
||||
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
|
||||
# obs_header_new = [f"post_{o}" for o in obs_space_description]
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
header = header + action_header + obs_space_description
|
||||
|
||||
try:
|
||||
filename = session_path / f"all_transactions_{timestamp_str}.csv"
|
||||
@@ -82,8 +82,7 @@ def write_transaction_to_file(
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ transaction.obs_space_pre.tolist()
|
||||
+ transaction.obs_space_post.tolist()
|
||||
+ transaction.obs_space.tolist()
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user