Fix logger inititialisation in agents
This commit is contained in:
@@ -68,6 +68,10 @@ class AbstractAgent(BaseModel, ABC):
|
||||
)
|
||||
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.logger: AgentLog = AgentLog(agent_name=self.config.ref)
|
||||
|
||||
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
|
||||
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
@@ -81,6 +85,7 @@ class AbstractAgent(BaseModel, ABC):
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
print("cls identifier:", identifier)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
@@ -157,6 +162,8 @@ class AbstractAgent(BaseModel, ABC):
|
||||
def from_config(cls, config: Dict) -> AbstractAgent:
|
||||
"""Grab the relevant agent class and construct an instance from a config dict."""
|
||||
agent_type = config["type"]
|
||||
print("agent_type:", agent_type)
|
||||
print("cls._registry:", cls._registry)
|
||||
agent_class = cls._registry[agent_type]
|
||||
return agent_class(config=config)
|
||||
|
||||
|
||||
@@ -493,6 +493,7 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
print("agent_cfg:", agent_cfg)
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
|
||||
@@ -29,18 +29,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"cls identifier: AbstractScriptedAgent\n",
|
||||
"cls identifier: ProxyAgent\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.agent.scripted_agents import probabilistic_agent\n",
|
||||
"import yaml"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -49,7 +59,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -69,9 +79,33 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"agent_cfg: {'ref': 'client_2_green_user', 'team': 'GREEN', 'type': 'ProbabilisticAgent', 'agent_settings': {'action_probabilities': {0: 0.3, 1: 0.6, 2: 0.1}}, 'action_space': {'action_map': {0: {'action': 'do_nothing', 'options': {}}, 1: {'action': 'node_application_execute', 'options': {'node_name': 'client_2', 'application_name': 'WebBrowser'}}, 2: {'action': 'node_application_execute', 'options': {'node_name': 'client_2', 'application_name': 'DatabaseClient'}}}}, 'reward_function': {'reward_components': [{'type': 'WEBPAGE_UNAVAILABLE_PENALTY', 'weight': 0.25, 'options': {'node_hostname': 'client_2'}}, {'type': 'GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY', 'weight': 0.05, 'options': {'node_hostname': 'client_2'}}]}}\n",
|
||||
"agent_type: ProbabilisticAgent\n",
|
||||
"cls._registry: {'AbstractScriptedAgent': <class 'primaite.game.agent.interface.AbstractScriptedAgent'>, 'ProxyAgent': <class 'primaite.game.agent.interface.ProxyAgent'>}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyError",
|
||||
"evalue": "'ProbabilisticAgent'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m gym \u001b[38;5;241m=\u001b[39m \u001b[43mPrimaiteGymEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/arcd/PrimAITE/src/primaite/session/environment.py:71\u001b[0m, in \u001b[0;36mPrimaiteGymEnv.__init__\u001b[0;34m(self, env_config)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mio \u001b[38;5;241m=\u001b[39m PrimaiteIO\u001b[38;5;241m.\u001b[39mfrom_config(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepisode_scheduler(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mio_settings\u001b[39m\u001b[38;5;124m\"\u001b[39m, {}))\n\u001b[1;32m 70\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Handles IO for the environment. This produces sys logs, agent logs, etc.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgame: PrimaiteGame \u001b[38;5;241m=\u001b[39m \u001b[43mPrimaiteGame\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_config\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepisode_scheduler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Current game.\"\"\"\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agent_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgame\u001b[38;5;241m.\u001b[39mrl_agents))\n",
|
||||
"File \u001b[0;32m~/arcd/PrimAITE/src/primaite/game/game.py:497\u001b[0m, in \u001b[0;36mPrimaiteGame.from_config\u001b[0;34m(cls, cfg)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m agent_cfg \u001b[38;5;129;01min\u001b[39;00m agents_cfg:\n\u001b[1;32m 496\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_cfg:\u001b[39m\u001b[38;5;124m\"\u001b[39m, agent_cfg)\n\u001b[0;32m--> 497\u001b[0m new_agent \u001b[38;5;241m=\u001b[39m \u001b[43mAbstractAgent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 498\u001b[0m game\u001b[38;5;241m.\u001b[39magents[agent_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mref\u001b[39m\u001b[38;5;124m\"\u001b[39m]] \u001b[38;5;241m=\u001b[39m new_agent\n\u001b[1;32m 499\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(new_agent, ProxyAgent):\n",
|
||||
"File \u001b[0;32m~/arcd/PrimAITE/src/primaite/game/agent/interface.py:163\u001b[0m, in \u001b[0;36mAbstractAgent.from_config\u001b[0;34m(cls, config)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_type:\u001b[39m\u001b[38;5;124m\"\u001b[39m, agent_type)\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcls._registry:\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_registry)\n\u001b[0;32m--> 163\u001b[0m agent_class \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_registry\u001b[49m\u001b[43m[\u001b[49m\u001b[43magent_type\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m agent_class(config\u001b[38;5;241m=\u001b[39mconfig)\n",
|
||||
"\u001b[0;31mKeyError\u001b[0m: 'ProbabilisticAgent'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -191,7 +225,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user