Remove redundant cols from transactions

This commit is contained in:
Marek Wolan
2023-06-30 10:41:56 +01:00
parent c3c4512544
commit 2a8d28cba6
4 changed files with 8 additions and 20 deletions

View File

@@ -168,7 +168,7 @@ class NodeLinkTable(AbstractObservationComponent):
f"link_{link_id}_n/a",
]
for j, serv in enumerate(self.env.services_list):
link_labels.append(f"node_{node_id}_service_{serv}_load")
link_labels.append(f"link_{link_id}_service_{serv}_load")
structure.extend(link_labels)
return structure

View File

@@ -318,7 +318,7 @@ class Primaite(Env):
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space_pre(self.obs_handler._flat_observation)
transaction.set_obs_space(self.obs_handler._flat_observation)
# Load the action space into the transaction
transaction.set_action_space(copy.deepcopy(action))
@@ -400,8 +400,6 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# Load the new observation space into the transaction
transaction.set_obs_space_post(self.obs_handler._flat_observation)
# 8. Add the transaction to the list of transactions
self.transaction_list.append(copy.deepcopy(transaction))

View File

@@ -20,23 +20,14 @@ class Transaction(object):
self.episode_number = _episode_number
self.step_number = _step_number
def set_obs_space_pre(self, _obs_space_pre):
def set_obs_space(self, _obs_space):
"""
Sets the observation space (pre).
Args:
_obs_space_pre: The observation space before any actions are taken
"""
self.obs_space_pre = _obs_space_pre
def set_obs_space_post(self, _obs_space_post):
"""
Sets the observation space (post).
Args:
_obs_space_post: The observation space after any actions are taken
"""
self.obs_space_post = _obs_space_post
self.obs_space = _obs_space
def set_reward(self, _reward):
"""

View File

@@ -58,12 +58,12 @@ def write_transaction_to_file(
action_header.append("AS_" + str(x))
# Create the observation space headers array
obs_header_initial = [f"pre_{o}" for o in obs_space_description]
obs_header_new = [f"post_{o}" for o in obs_space_description]
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
# obs_header_new = [f"post_{o}" for o in obs_space_description]
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
header = header + action_header + obs_space_description
try:
filename = session_path / f"all_transactions_{timestamp_str}.csv"
@@ -82,8 +82,7 @@ def write_transaction_to_file(
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ transaction.obs_space_pre.tolist()
+ transaction.obs_space_post.tolist()
+ transaction.obs_space.tolist()
)
csv_writer.writerow(csv_data)