#2085 - Added step metadata json file dumps to the environments. Fixed serialization issues in the Switch and ACLRule classes.

This commit is contained in:
Chris McCarthy
2023-12-01 16:37:58 +00:00
parent 32c13e06f6
commit cc04efb31d
4 changed files with 41 additions and 16 deletions

View File

@@ -31,4 +31,6 @@ Outputs
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model.
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
contains the action, reward, and simulation state. These can be found in
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``

View File

@@ -5,9 +5,9 @@ import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import PRIMAITE_PATHS
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator import SIM_OUTPUT
class PrimaiteGymEnv(gymnasium.Env):
@@ -33,17 +33,6 @@ class PrimaiteGymEnv(gymnasium.Env):
self.game.advance_timestep()
state = self.game.get_sim_state()
# Create state suitable for dumping to file.
# dump_state = {self.game.episode_counter: {self.game.step_counter: state}}
# Dump to file
# if os.path.isfile(PRIMAITE_PATHS.episode_steps_log_file_path):
with open(PRIMAITE_PATHS.episode_log_file_path, "a", encoding="utf-8") as f:
# f.write(str(dump_state))
# f.write("\n=================\n")
# f.flush()
json.dump(state, f)
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -51,9 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
self._write_step_metadata_json(action, state, reward)
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"step": self.game.step_counter,
"action": int(action),
"reward": int(reward),
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.game.reset()
@@ -173,8 +179,25 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
infos = {}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
self._write_step_metadata_json(actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"step": self.game.step_counter,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}

View File

@@ -66,9 +66,9 @@ class ACLRule(SimComponent):
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
return state

View File

@@ -57,7 +57,7 @@ class Switch(Node):
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):