From 3f94c40434d11034cd1dc5dc9e07800c03bbcec4 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 22 Jan 2025 10:49:42 +0000 Subject: [PATCH] Fix logger inititialisation in agents --- src/primaite/game/agent/interface.py | 7 +++ src/primaite/game/game.py | 1 + .../notebooks/Training-an-SB3-Agent.ipynb | 48 ++++++++++++++++--- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index aac898e1..6a6bc323 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -68,6 +68,10 @@ class AbstractAgent(BaseModel, ABC): ) reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.logger: AgentLog = AgentLog(agent_name=self.config.ref) + config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) logger: AgentLog = AgentLog(agent_name="Abstract_Agent") @@ -81,6 +85,7 @@ class AbstractAgent(BaseModel, ABC): def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) + print("cls identifier:", identifier) if identifier is None: return if identifier in cls._registry: @@ -157,6 +162,8 @@ class AbstractAgent(BaseModel, ABC): def from_config(cls, config: Dict) -> AbstractAgent: """Grab the relevant agent class and construct an instance from a config dict.""" agent_type = config["type"] + print("agent_type:", agent_type) + print("cls._registry:", cls._registry) agent_class = cls._registry[agent_type] return agent_class(config=config) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f59117f4..c828a462 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -493,6 +493,7 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: + print("agent_cfg:", agent_cfg) new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent if isinstance(new_agent, ProxyAgent): diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 2b554475..0e3245ae 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -29,18 +29,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cls identifier: AbstractScriptedAgent\n", + "cls identifier: ProxyAgent\n" + ] + } + ], "source": [ "from primaite.game.game import PrimaiteGame\n", "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.game.agent.scripted_agents import probabilistic_agent\n", "import yaml" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -49,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -69,9 +79,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "agent_cfg: {'ref': 'client_2_green_user', 'team': 'GREEN', 'type': 'ProbabilisticAgent', 'agent_settings': {'action_probabilities': {0: 0.3, 1: 0.6, 2: 0.1}}, 'action_space': {'action_map': {0: {'action': 'do_nothing', 'options': {}}, 1: {'action': 'node_application_execute', 'options': {'node_name': 'client_2', 'application_name': 'WebBrowser'}}, 2: {'action': 'node_application_execute', 'options': {'node_name': 'client_2', 'application_name': 'DatabaseClient'}}}}, 'reward_function': {'reward_components': [{'type': 'WEBPAGE_UNAVAILABLE_PENALTY', 'weight': 0.25, 'options': {'node_hostname': 'client_2'}}, {'type': 'GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY', 'weight': 0.05, 'options': {'node_hostname': 'client_2'}}]}}\n", + "agent_type: ProbabilisticAgent\n", + "cls._registry: {'AbstractScriptedAgent': , 'ProxyAgent': }\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'ProbabilisticAgent'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m gym \u001b[38;5;241m=\u001b[39m \u001b[43mPrimaiteGymEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/arcd/PrimAITE/src/primaite/session/environment.py:71\u001b[0m, in \u001b[0;36mPrimaiteGymEnv.__init__\u001b[0;34m(self, env_config)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mio \u001b[38;5;241m=\u001b[39m PrimaiteIO\u001b[38;5;241m.\u001b[39mfrom_config(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepisode_scheduler(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mio_settings\u001b[39m\u001b[38;5;124m\"\u001b[39m, {}))\n\u001b[1;32m 70\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Handles IO for the environment. This produces sys logs, agent logs, etc.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgame: PrimaiteGame \u001b[38;5;241m=\u001b[39m \u001b[43mPrimaiteGame\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_config\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepisode_scheduler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Current game.\"\"\"\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agent_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgame\u001b[38;5;241m.\u001b[39mrl_agents))\n", + "File \u001b[0;32m~/arcd/PrimAITE/src/primaite/game/game.py:497\u001b[0m, in \u001b[0;36mPrimaiteGame.from_config\u001b[0;34m(cls, cfg)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m agent_cfg \u001b[38;5;129;01min\u001b[39;00m agents_cfg:\n\u001b[1;32m 496\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_cfg:\u001b[39m\u001b[38;5;124m\"\u001b[39m, agent_cfg)\n\u001b[0;32m--> 497\u001b[0m new_agent \u001b[38;5;241m=\u001b[39m \u001b[43mAbstractAgent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 498\u001b[0m game\u001b[38;5;241m.\u001b[39magents[agent_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mref\u001b[39m\u001b[38;5;124m\"\u001b[39m]] \u001b[38;5;241m=\u001b[39m new_agent\n\u001b[1;32m 499\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(new_agent, ProxyAgent):\n", + "File \u001b[0;32m~/arcd/PrimAITE/src/primaite/game/agent/interface.py:163\u001b[0m, in \u001b[0;36mAbstractAgent.from_config\u001b[0;34m(cls, config)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_type:\u001b[39m\u001b[38;5;124m\"\u001b[39m, agent_type)\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcls._registry:\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_registry)\n\u001b[0;32m--> 163\u001b[0m agent_class \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_registry\u001b[49m\u001b[43m[\u001b[49m\u001b[43magent_type\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m agent_class(config\u001b[38;5;241m=\u001b[39mconfig)\n", + "\u001b[0;31mKeyError\u001b[0m: 'ProbabilisticAgent'" + ] + } + ], "source": [ "gym = PrimaiteGymEnv(env_config=cfg)" ] @@ -191,7 +225,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4,