diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index c653bd05..2dc70ca6 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -20,7 +20,6 @@ The inheritance structure of agents within PrimAITE are shown below. When develo All agent types within PrimAITE are listed under the ``_registry`` attribute of the parent class, ``AbstractAgent``. # TODO: Turn this into an inheritance diagram -# TODO: Would this be necessary? AbstractAgent | @@ -61,7 +60,7 @@ AbstractAgent class ConfigSchema(AbstractAgent.ConfigSchema): """ExampleAgent configuration schema""" - agent_name: str = "ExampleAgent + type: str = "ExampleAgent """Name of agent""" starting_host: int """Host node that this agent should start from in the given environment.""" @@ -97,7 +96,6 @@ AbstractAgent start_step: 25 frequency: 20 variance: 5 - agent_name: "Example Agent" starting_host: "Server_1" diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 98c6a337..31d74176 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -21,9 +21,10 @@ class _NotJSONFilter(logging.Filter): class AgentLog: """ - A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. + An Agent Log class is a simple logger dedicated to managing and writing updates and information for an agent. - Each log message is written to a file located at: /agent_name/agent_name.log + Each log message is written to a file located at: + /agent_name/agent_name.log """ def __init__(self, agent_name: Optional[str]): diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 794ce511..370e6bbb 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -115,6 +115,8 @@ class AbstractAgent(BaseModel): @classmethod def from_config(cls, config: Dict) -> "AbstractAgent": """Creates an agent component from a configuration dictionary.""" + if config["type"] not in cls._registry: + return ValueError(f"Invalid Agent Type: {config['type']}") obj = cls( config=cls.ConfigSchema(**config["agent_settings"]), action_manager=ActionManager.from_config(config["game"], config["action_manager"]), diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index 455c996b..f3d9ee08 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -21,7 +21,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration schema for Probabilistic Agent.""" - agent_name: str = "ProbabilisticAgent" + type: str = "ProbabilisticAgent" action_probabilities: Dict[int, float] = None """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 69540f0a..6e5fb94d 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -61,16 +61,6 @@ def test_probabilistic_agent(): reward_function_cfg = {} - # pa = ProbabilisticAgent( - # agent_name="test_agent", - # action_space=action_space, - # observation_space=observation_space, - # reward_function=reward_function, - # settings={ - # "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - # }, - # ) - pa_config = { "agent_name": "test_agent", "game": PrimaiteGame(),