Minor fixes to rewards

This commit is contained in:
Marek Wolan
2023-10-12 09:59:45 +01:00
parent 70c1857bbc
commit 565af11dba
2 changed files with 10 additions and 10 deletions

View File

@@ -15,9 +15,9 @@ class AbstractReward:
def calculate(self, state: Dict) -> float:
return 0.0
@abstractmethod
@classmethod
def from_config(cls, config:dict) -> "AbstractReward":
@abstractmethod
def from_config(cls, config:dict, session:"PrimaiteSession") -> "AbstractReward":
return cls()
@@ -26,12 +26,12 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict) -> "DummyReward":
def from_config(cls, config: dict, session:"PrimaiteSession") -> "DummyReward":
return cls()
class DatabaseFileIntegrity(AbstractReward):
def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None:
self.location_in_state = ["network", "node", node_uuid, "file_system", ""]
self.location_in_state = ["network", "nodes", node_uuid, "file_system", "folders",folder_name, "files", file_name]
def calculate(self, state: Dict) -> float:
database_file_state = access_from_nested_dict(state, self.location_in_state)
@@ -57,7 +57,7 @@ class DatabaseFileIntegrity(AbstractReward):
if not file_name:
_LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified")
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref].uuid
node_uuid = session.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation")
return DummyReward() # TODO: better error handling
@@ -66,7 +66,7 @@ class DatabaseFileIntegrity(AbstractReward):
class WebServer404Penalty(AbstractReward):
def __init__(self, node_uuid:str, service_uuid:str) -> None:
self.location_in_state = ['network','node', node_uuid, 'services', service_uuid]
self.location_in_state = ['network','nodes', node_uuid, 'services', service_uuid]
def calculate(self, state: Dict) -> float:
web_service_state = access_from_nested_dict(state, self.location_in_state)
@@ -86,7 +86,7 @@ class WebServer404Penalty(AbstractReward):
msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config."
_LOGGER.warn(msg)
return DummyReward() #TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref].uuid
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator."
@@ -124,7 +124,7 @@ class RewardFunction:
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg["weight"]
weight = rew_component_cfg.get("weight",1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session)
new.regsiter_component(component=rew_instance, weight=weight)

View File

@@ -295,7 +295,7 @@ class PrimaiteSession:
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[node_ref] = new_node.uuid
sess.ref_map_nodes[node_ref] = new_node.uuid # TODO: fix incosistency with service and link. Node gets added by uuid, but service gets reference to object
# 2. create links between nodes
for link_cfg in links_cfg:
@@ -350,7 +350,7 @@ class PrimaiteSession:
action_space = ActionManager.from_config(sess, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg)
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":