#2869 - Updates to agents to make sure they can be generated from a given config. Updates to test suite to reflect code changes

This commit is contained in:
Charlie Crane
2024-12-16 15:57:00 +00:00
parent d9a1a0e26f
commit a4fbd29bb4
8 changed files with 83 additions and 63 deletions

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple
import pytest
import yaml
@@ -10,6 +10,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.interface import AbstractAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network
@@ -268,12 +269,12 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
"""Agent that can be controlled by the tests."""
config: "ControlledAgent.ConfigSchema"
most_recent_action: Optional[Tuple[str, Dict]] = None
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Abstract Agent used in tests."""
agent_name: str = "Controlled_Agent"
most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
@@ -496,12 +497,15 @@ def game_and_agent():
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
test_agent = ControlledAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
config = {
"agent_name":"test_agent",
"action_manager":action_space,
"observation_manager":observation_space,
"reward_function":reward_function,
}
test_agent = ControlledAgent.from_config(config=config)
game.agents["test_agent"] = test_agent