Minor fixes to rewards
This commit is contained in:
@@ -15,9 +15,9 @@ class AbstractReward:
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return 0.0
|
||||
|
||||
@abstractmethod
|
||||
@classmethod
|
||||
def from_config(cls, config:dict) -> "AbstractReward":
|
||||
@abstractmethod
|
||||
def from_config(cls, config:dict, session:"PrimaiteSession") -> "AbstractReward":
|
||||
return cls()
|
||||
|
||||
|
||||
@@ -26,12 +26,12 @@ class DummyReward(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "DummyReward":
|
||||
def from_config(cls, config: dict, session:"PrimaiteSession") -> "DummyReward":
|
||||
return cls()
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None:
|
||||
self.location_in_state = ["network", "node", node_uuid, "file_system", ""]
|
||||
self.location_in_state = ["network", "nodes", node_uuid, "file_system", "folders",folder_name, "files", file_name]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
@@ -57,7 +57,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
if not file_name:
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified")
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = session.ref_map_nodes[node_ref].uuid
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation")
|
||||
return DummyReward() # TODO: better error handling
|
||||
@@ -66,7 +66,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
def __init__(self, node_uuid:str, service_uuid:str) -> None:
|
||||
self.location_in_state = ['network','node', node_uuid, 'services', service_uuid]
|
||||
self.location_in_state = ['network','nodes', node_uuid, 'services', service_uuid]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
@@ -86,7 +86,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config."
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() #TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = session.ref_map_nodes[node_ref].uuid
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
service_uuid = session.ref_map_services[service_ref].uuid
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator."
|
||||
@@ -124,7 +124,7 @@ class RewardFunction:
|
||||
|
||||
for rew_component_cfg in config["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
weight = rew_component_cfg["weight"]
|
||||
weight = rew_component_cfg.get("weight",1.0)
|
||||
rew_class = cls.__rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session)
|
||||
new.regsiter_component(component=rew_instance, weight=weight)
|
||||
|
||||
@@ -295,7 +295,7 @@ class PrimaiteSession:
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
sess.ref_map_nodes[node_ref] = new_node.uuid
|
||||
sess.ref_map_nodes[node_ref] = new_node.uuid # TODO: fix inconsistency with services and links. Nodes are stored by uuid, but services are stored as a reference to the object
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
@@ -350,7 +350,7 @@ class PrimaiteSession:
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg)
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
|
||||
Reference in New Issue
Block a user