#2085 - Added step metadata json file dumps to the environments. Fixed serialization issues in the Switch and ACLRule classes.
This commit is contained in:
@@ -31,4 +31,6 @@ Outputs
|
||||
|
||||
Running a session creates a session output directory in your user data folder. The filepath looks like this:
|
||||
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
|
||||
the saved agent checkpoints, and final model.
|
||||
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
|
||||
contains the action, reward, and simulation state. These can be found in
|
||||
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``
|
||||
|
||||
@@ -5,9 +5,9 @@ import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
@@ -33,17 +33,6 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
# Create state suitable for dumping to file.
|
||||
# dump_state = {self.game.episode_counter: {self.game.step_counter: state}}
|
||||
|
||||
# Dump to file
|
||||
# if os.path.isfile(PRIMAITE_PATHS.episode_steps_log_file_path):
|
||||
with open(PRIMAITE_PATHS.episode_log_file_path, "a", encoding="utf-8") as f:
|
||||
# f.write(str(dump_state))
|
||||
# f.write("\n=================\n")
|
||||
# f.flush()
|
||||
json.dump(state, f)
|
||||
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
@@ -51,9 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"action": int(action),
|
||||
"reward": int(reward),
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.game.reset()
|
||||
@@ -173,8 +179,25 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
infos = {}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
|
||||
@@ -66,9 +66,9 @@ class ACLRule(SimComponent):
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class Switch(Node):
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"] = self.num_ports # redundant?
|
||||
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
|
||||
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
|
||||
Reference in New Issue
Block a user