Merged PR 396: Merge fixes for bugs found during testing and updates to docs
- Merge branch 'feature/2588-minimal-deps' into bugfix/2626-too-many-open-files - #2626 fix too many open files bug - Merge remote-tracking branch 'origin/dev' into bugfix/reward-logging - fix reward logging - Merge remote-tracking branch 'origin/dev' into bugfix/reward-logging - get ray to stop crashing - Merged PR 394: Fix 'too many open files' - Merged PR 395: Fix reward logging and get ray to stop complaining - Fix firewall diagram Related work items: #2588, #2626
This commit is contained in:
BIN
docs/_static/firewall_acl.png
vendored
BIN
docs/_static/firewall_acl.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 23 KiB |
@@ -19,4 +19,3 @@
|
||||
:recursive:
|
||||
|
||||
primaite
|
||||
tests
|
||||
|
||||
@@ -125,7 +125,6 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/state_system
|
||||
source/request_system
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
PrimAITE Tests <source/_autosummary/tests>
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
||||
@@ -78,4 +78,4 @@ Glossary
|
||||
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
|
||||
|
||||
User app home
|
||||
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\\Users\\<username>\\primaite\\<version>` on Windows.
|
||||
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>/` on linux/darwin and `C:\\Users\\<username>\\primaite<version>` on Windows.
|
||||
|
||||
@@ -54,7 +54,7 @@ license-files = ["LICENSE"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
class AgentHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
reward: Optional[float] = None
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
return self.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
@@ -34,7 +34,7 @@ from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
@@ -44,7 +44,7 @@ class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -64,7 +64,7 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@@ -389,7 +389,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
|
||||
@@ -160,6 +160,7 @@ class PrimaiteGame:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.save_reward_to_history()
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
"# Imports\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint"
|
||||
@@ -63,7 +63,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
|
||||
@@ -392,7 +392,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -444,7 +444,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
@@ -705,7 +705,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"\n",
|
||||
"import ray\n",
|
||||
@@ -60,8 +60,8 @@
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
@@ -52,8 +51,8 @@
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -74,7 +73,7 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128}\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
|
||||
@@ -43,7 +43,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)"
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == 'defender':\n",
|
||||
" agent['agent_settings']['flatten_obs']=True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -177,7 +180,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -298,8 +298,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_A'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_A'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
@@ -329,8 +329,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_B'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_B'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
|
||||
@@ -11,6 +11,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -60,7 +61,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, action, state, reward)
|
||||
@@ -89,9 +90,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
PacketCapture.clear()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
@@ -125,5 +127,5 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
@@ -87,7 +87,7 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
|
||||
@@ -11,6 +11,7 @@ from primaite.session.environment import _LOGGER, PrimaiteGymEnv
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
@@ -45,7 +46,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@@ -59,10 +61,11 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
PacketCapture.clear()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
@@ -138,8 +141,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
|
||||
@@ -21,6 +21,8 @@ class PacketCapture:
|
||||
The PCAPs are logged to: <simulation output directory>/<hostname>/<hostname>_<ip address>_pcap.log
|
||||
"""
|
||||
|
||||
_logger_instances: List[logging.Logger] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
@@ -65,10 +67,12 @@ class PacketCapture:
|
||||
|
||||
if outbound:
|
||||
self.outbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
PacketCapture._logger_instances.append(self.outbound_logger)
|
||||
logger = self.outbound_logger
|
||||
else:
|
||||
self.inbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
logger = self.inbound_logger
|
||||
PacketCapture._logger_instances.append(self.inbound_logger)
|
||||
|
||||
logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs
|
||||
logger.addHandler(file_handler)
|
||||
@@ -122,3 +126,13 @@ class PacketCapture:
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@staticmethod
|
||||
def clear():
|
||||
"""Close all open PCAP file handlers."""
|
||||
for logger in PacketCapture._logger_instances:
|
||||
handlers = logger.handlers[:]
|
||||
for handler in handlers:
|
||||
logger.removeHandler(handler)
|
||||
handler.close()
|
||||
PacketCapture._logger_instances = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
@@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
@@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user