Merge remote-tracking branch 'origin/dev' into feature/898-Fix-the-functionality-of-resetting-a-node
This commit is contained in:
0
src/primaite/agents/__init__.py
Normal file
0
src/primaite/agents/__init__.py
Normal file
127
src/primaite/agents/utils.py
Normal file
127
src/primaite/agents/utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
|
||||
2
|
||||
] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
node_property = action_r[1]
|
||||
node_action = action_r[2]
|
||||
|
||||
# print("node property", node_property, "\nnode action", node_action)
|
||||
|
||||
if node_property == "NONE":
|
||||
return False
|
||||
if node_action == "NONE":
|
||||
return False
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
action_decision = action_r[0]
|
||||
action_permission = action_r[1]
|
||||
action_source_id = action_r[2]
|
||||
action_destination_id = action_r[3]
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
# DENY is unnecessary, we can create and delete allow rules instead
|
||||
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action."""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
action_r = transform_action_acl_readable(action)
|
||||
action_protocol = action_r[4]
|
||||
action_port = action_r[5]
|
||||
|
||||
# Don't allow protocols or ports to be ANY
|
||||
# in the future we might want to do the opposite, and only have ANY option for ports and service
|
||||
if action_protocol == "ANY":
|
||||
return False
|
||||
if action_port == "ANY":
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -9,6 +9,7 @@ class ConfigValuesMain(object):
|
||||
"""Init."""
|
||||
# Generic
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.observation_config = None # observation space config
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
|
||||
@@ -51,6 +51,7 @@ class SoftwareState(Enum):
|
||||
class NodePOLType(Enum):
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
NONE = 0
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
@@ -83,6 +84,7 @@ class ActionType(Enum):
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
class ObservationType(Enum):
|
||||
@@ -100,3 +102,29 @@ class FileSystemState(Enum):
|
||||
DESTROYED = 3
|
||||
REPAIRING = 4
|
||||
RESTORING = 5
|
||||
|
||||
|
||||
class NodeHardwareAction(Enum):
|
||||
"""Node hardware action."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESET = 3
|
||||
|
||||
|
||||
class NodeSoftwareAction(Enum):
|
||||
"""Node software action."""
|
||||
|
||||
NONE = 0
|
||||
PATCHING = 1
|
||||
|
||||
|
||||
class LinkStatus(Enum):
|
||||
"""Link traffic status."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
403
src/primaite/environment/observations.py
Normal file
403
src/primaite/environment/observations.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
# This dependency is only needed for type hints,
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite"):
|
||||
_LOGGER.info(f"Initialising {self} observation component")
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self):
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||
|
||||
This will create the observation space formatted as a table of integers.
|
||||
There is one row per node, followed by one row per link.
|
||||
The number of columns is 4 plus one per service. They are:
|
||||
* node/link ID
|
||||
* node hardware status / 0 for links
|
||||
* node operating system status (if active/service) / 0 for links
|
||||
* node file system status (active/service only) / 0 for links
|
||||
* node service1 status / traffic load from that service for links
|
||||
* node service2 status / traffic load from that service for links
|
||||
* ...
|
||||
* node serviceN status / traffic load from that service for links
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(12, 7)``
|
||||
"""
|
||||
|
||||
_FIXED_PARAMETERS: int = 4
|
||||
_MAX_VAL: int = 1_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
num_items = self.env.num_links + self.env.num_nodes
|
||||
num_columns = self.env.num_services + self._FIXED_PARAMETERS
|
||||
observation_shape = (num_items, num_columns)
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.Box(
|
||||
low=0,
|
||||
high=self._MAX_VAL,
|
||||
shape=observation_shape,
|
||||
dtype=self._DATA_TYPE,
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||
"""
|
||||
item_index = 0
|
||||
nodes = self.env.nodes
|
||||
links = self.env.links
|
||||
# Do nodes first
|
||||
for _, node in nodes.items():
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][
|
||||
3
|
||||
] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.env.services_list:
|
||||
if node.has_service(service):
|
||||
self.current_observation[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.env.services_list:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for _, link in links.items():
|
||||
self.current_observation[item_index][0] = int(link.get_id())
|
||||
self.current_observation[item_index][1] = 0
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.current_observation[item_index][
|
||||
protocol_index + 4
|
||||
] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
Each node has 3 elements plus 1 per service. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
node1 hardware state,
|
||||
node1 OS state,
|
||||
node1 file system state,
|
||||
node1 service1 state,
|
||||
node1 service2 state,
|
||||
node1 serviceN state (one for each service),
|
||||
node2 hardware state,
|
||||
node2 OS state,
|
||||
node2 file system state,
|
||||
node2 service1 state,
|
||||
node2 service2 state,
|
||||
node2 serviceN state (one for each service),
|
||||
...
|
||||
]
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
node_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
||||
node_shape = node_shape + services_shape
|
||||
|
||||
shape = node_shape * self.env.num_nodes
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||
"""
|
||||
obs = []
|
||||
for _, node in self.env.nodes.items():
|
||||
hardware_state = node.hardware_state.value
|
||||
software_state = 0
|
||||
file_system_state = 0
|
||||
service_states = [0] * self.env.num_services
|
||||
|
||||
if isinstance(node, ActiveNode):
|
||||
software_state = node.software_state.value
|
||||
file_system_state = node.file_system_state_observed.value
|
||||
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend(
|
||||
[hardware_state, software_state, file_system_state, *service_states]
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""Flat list of traffic levels encoded into banded categories.
|
||||
|
||||
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||
0 = No traffic (0% of bandwidth)
|
||||
1 = No traffic (0%-33% of bandwidth)
|
||||
2 = No traffic (33%-66% of bandwidth)
|
||||
3 = No traffic (66%-100% of bandwidth)
|
||||
4 = No traffic (100% of bandwidth)
|
||||
|
||||
.. note::
|
||||
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
|
||||
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
|
||||
defaults to False
|
||||
:type combine_service_traffic: bool, optional
|
||||
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
|
||||
defaults to 5
|
||||
:type quantisation_levels: int, optional
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
if quantisation_levels < 3:
|
||||
_msg = (
|
||||
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
||||
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
||||
f"Resetting to default value (5)"
|
||||
)
|
||||
_LOGGER.warning(_msg)
|
||||
quantisation_levels = 5
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
self._combine_service_traffic: bool = combine_service_traffic
|
||||
self._quantisation_levels: int = quantisation_levels
|
||||
self._entries_per_link: int = 1
|
||||
|
||||
if not self._combine_service_traffic:
|
||||
self._entries_per_link = self.env.num_services
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = (
|
||||
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
)
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||
"""
|
||||
obs = []
|
||||
for _, link in self.env.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
traffic_level = 0
|
||||
elif load >= bandwidth:
|
||||
traffic_level = self._quantisation_levels - 1
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (
|
||||
1 / (self._quantisation_levels - 2)
|
||||
) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
|
||||
This allows users to configure observation spaces by mixing and matching components.
|
||||
Each component can also define further parameters to make them more flexible.
|
||||
"""
|
||||
|
||||
_REGISTRY: Final[Dict[str, type]] = {
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
|
||||
def update_obs(self):
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
|
||||
:param obs_component: The component to add.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
"""Remove a component from this handler.
|
||||
|
||||
:param obs_component: Which component to remove. It must exist within this object's
|
||||
``registered_obs_components`` attribute.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self):
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
The expected format for the config dictionary is:
|
||||
|
||||
..code-block::python
|
||||
config = {
|
||||
components: [
|
||||
{
|
||||
"name": "<COMPONENT1_NAME>"
|
||||
},
|
||||
{
|
||||
"name": "<COMPONENT2_NAME>"
|
||||
"options": {"opt1": val1, "opt2": val2}
|
||||
},
|
||||
{
|
||||
...
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
:return: Observation handler
|
||||
:rtype: primaite.environment.observations.ObservationsHandler
|
||||
"""
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
comp_class = cls._REGISTRY[comp_type]
|
||||
|
||||
# Create the component with options from the YAML
|
||||
options = component_cfg.get("options") or {}
|
||||
component = comp_class(env, **options)
|
||||
|
||||
handler.register(component)
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
@@ -15,6 +15,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -23,11 +24,11 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.environment.observations import ObservationsHandler
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -42,19 +43,17 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
# Observation / Action Space contants
|
||||
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
|
||||
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 4
|
||||
ACTION_SPACE_ACL_ACTION_VALUES = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
|
||||
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
# Action Space contants
|
||||
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
|
||||
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
|
||||
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(self, _config_values, _transaction_list):
|
||||
"""
|
||||
@@ -149,8 +148,14 @@ class Primaite(Env):
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
|
||||
# Observation type, by default box.
|
||||
self.observation_type = ObservationType.BOX
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
if self.config_values.observation_config is not None:
|
||||
self.obs_config = self.config_values.observation_config
|
||||
|
||||
# Observation Handler manages the user-configurable observation space.
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
@@ -202,15 +207,9 @@ class Primaite(Env):
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.num_nodes,
|
||||
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
|
||||
self.ACTION_SPACE_NODE_ACTION_VALUES,
|
||||
self.num_services,
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
@@ -219,17 +218,14 @@ class Primaite(Env):
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.ACTION_SPACE_ACL_ACTION_VALUES,
|
||||
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
|
||||
self.num_nodes + 1,
|
||||
self.num_nodes + 1,
|
||||
self.num_services + 1,
|
||||
self.num_ports + 1,
|
||||
]
|
||||
)
|
||||
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ANY:
|
||||
_LOGGER.info("Action space type ANY selected - Node + ACL")
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info("Invalid action type selected")
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
@@ -368,14 +364,14 @@ class Primaite(Env):
|
||||
# 5. Calculate reward signal (for RL)
|
||||
reward = calculate_reward_function(
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_blue,
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.config_values,
|
||||
)
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -432,8 +428,18 @@ class Primaite(Env):
|
||||
# At the moment, actions are only affecting nodes
|
||||
if self.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
elif self.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
@@ -442,10 +448,11 @@ class Primaite(Env):
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
node_id = _action[0]
|
||||
node_property = _action[1]
|
||||
property_action = _action[2]
|
||||
service_index = _action[3]
|
||||
readable_action = self.action_dict[_action]
|
||||
node_id = readable_action[0]
|
||||
node_property = readable_action[1]
|
||||
property_action = readable_action[2]
|
||||
service_index = readable_action[3]
|
||||
|
||||
# Check that the action is requesting a valid node
|
||||
try:
|
||||
@@ -531,12 +538,15 @@ class Primaite(Env):
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
action_decision = _action[0]
|
||||
action_permission = _action[1]
|
||||
action_source_ip = _action[2]
|
||||
action_destination_ip = _action[3]
|
||||
action_protocol = _action[4]
|
||||
action_port = _action[5]
|
||||
# Convert discrete value back to multidiscrete
|
||||
readable_action = self.action_dict[_action]
|
||||
|
||||
action_decision = readable_action[0]
|
||||
action_permission = readable_action[1]
|
||||
action_source_ip = readable_action[2]
|
||||
action_destination_ip = readable_action[3]
|
||||
action_protocol = readable_action[4]
|
||||
action_port = readable_action[5]
|
||||
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
@@ -641,252 +651,20 @@ class Primaite(Env):
|
||||
else:
|
||||
pass
|
||||
|
||||
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||
"""Initialise the observation space with the BOX option chosen.
|
||||
|
||||
This will create the observation space formatted as a table of integers.
|
||||
There is one row per node, followed by one row per link.
|
||||
Columns are as follows:
|
||||
* node/link ID
|
||||
* node hardware status / 0 for links
|
||||
* node operating system status (if active/service) / 0 for links
|
||||
* node file system status (active/service only) / 0 for links
|
||||
* node service1 status / traffic load from that service for links
|
||||
* node service2 status / traffic load from that service for links
|
||||
* ...
|
||||
* node serviceN status / traffic load from that service for links
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(12, 7)``
|
||||
|
||||
:return: Box gym observation
|
||||
:rtype: gym.spaces.Box
|
||||
:return: Initial observation with all entires set to 0
|
||||
:rtype: numpy.Array
|
||||
"""
|
||||
_LOGGER.info("Observation space type BOX selected")
|
||||
|
||||
# 1. Determine observation shape from laydown
|
||||
num_items = self.num_links + self.num_nodes
|
||||
num_observation_parameters = (
|
||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
)
|
||||
observation_shape = (num_items, num_observation_parameters)
|
||||
|
||||
# 2. Create observation space & zeroed out sample from space.
|
||||
observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=self.OBSERVATION_SPACE_HIGH_VALUE,
|
||||
shape=observation_shape,
|
||||
dtype=np.int64,
|
||||
)
|
||||
initial_observation = np.zeros(observation_shape, dtype=np.int64)
|
||||
|
||||
return observation_space, initial_observation
|
||||
|
||||
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||
"""Initialise the observation space with the MULTIDISCRETE option chosen.
|
||||
|
||||
This will create the observation space with node observations followed by link observations.
|
||||
Each node has 3 elements in the observation space plus 1 per service, more specifically:
|
||||
* hardware state
|
||||
* operating system state
|
||||
* file system state
|
||||
* service states (one per service)
|
||||
Each link has one element in the observation space, corresponding to the traffic load,
|
||||
it can take the following values:
|
||||
0 = No traffic (0% of bandwidth)
|
||||
1 = No traffic (0%-33% of bandwidth)
|
||||
2 = No traffic (33%-66% of bandwidth)
|
||||
3 = No traffic (66%-100% of bandwidth)
|
||||
4 = No traffic (100% of bandwidth)
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(37,)``
|
||||
|
||||
:return: MultiDiscrete gym observation
|
||||
:rtype: gym.spaces.MultiDiscrete
|
||||
:return: Initial observation with all entires set to 0
|
||||
:rtype: numpy.Array
|
||||
"""
|
||||
_LOGGER.info("Observation space MULTIDISCRETE selected")
|
||||
|
||||
# 1. Determine observation shape from laydown
|
||||
node_obs_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
node_services = [len(SoftwareState) + 1] * self.num_services
|
||||
node_obs_shape = node_obs_shape + node_services
|
||||
# the magic number 5 refers to 5 states of quantisation of traffic amount.
|
||||
# (zero, low, medium, high, fully utilised/overwhelmed)
|
||||
link_obs_shape = [5] * self.num_links
|
||||
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
|
||||
|
||||
# 2. Create observation space & zeroed out sample from space.
|
||||
observation_space = spaces.MultiDiscrete(observation_shape)
|
||||
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
|
||||
|
||||
return observation_space, initial_observation
|
||||
|
||||
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||
"""Build the observation space based on network laydown and provide initial obs.
|
||||
"""Create the environment's observation handler.
|
||||
|
||||
This method uses the object's `num_links`, `num_nodes`, `num_services`,
|
||||
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
|
||||
attributes to figure out the correct shape and format for the observation space.
|
||||
|
||||
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
|
||||
:return: Gym observation space
|
||||
:rtype: gym.spaces.Space
|
||||
:return: Initial observation with all entires set to 0
|
||||
:rtype: numpy.Array
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
"""
|
||||
if self.observation_type == ObservationType.BOX:
|
||||
observation_space, initial_observation = self._init_box_observations()
|
||||
return observation_space, initial_observation
|
||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
||||
(
|
||||
observation_space,
|
||||
initial_observation,
|
||||
) = self._init_multidiscrete_observations()
|
||||
return observation_space, initial_observation
|
||||
else:
|
||||
errmsg = (
|
||||
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
|
||||
f", got {self.observation_type} instead"
|
||||
)
|
||||
_LOGGER.error(errmsg)
|
||||
raise ValueError(errmsg)
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
|
||||
def _update_env_obs_box(self):
|
||||
"""Update the environment's observation state based on the current status of nodes and links.
|
||||
|
||||
The structure of the observation space is described in :func:`~_init_box_observations`
|
||||
This function can only be called if the observation space setting is set to BOX.
|
||||
|
||||
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
|
||||
"""
|
||||
assert self.observation_type == ObservationType.BOX
|
||||
item_index = 0
|
||||
|
||||
# Do nodes first
|
||||
for node_key, node in self.nodes.items():
|
||||
self.env_obs[item_index][0] = int(node.node_id)
|
||||
self.env_obs[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.env_obs[item_index][2] = node.software_state.value
|
||||
self.env_obs[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.services_list:
|
||||
if node.has_service(service):
|
||||
self.env_obs[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.services_list:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for link_key, link in self.links.items():
|
||||
self.env_obs[item_index][0] = int(link.get_id())
|
||||
self.env_obs[item_index][1] = 0
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def _update_env_obs_multidiscrete(self):
|
||||
"""Update the environment's observation state based on the current status of nodes and links.
|
||||
|
||||
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
|
||||
This function can only be called if the observation space setting is set to MULTIDISCRETE.
|
||||
|
||||
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
|
||||
"""
|
||||
assert self.observation_type == ObservationType.MULTIDISCRETE
|
||||
obs = []
|
||||
# 1. Set nodes
|
||||
# Each node has the following variables in the observation space:
|
||||
# - Hardware state
|
||||
# - Software state
|
||||
# - File System state
|
||||
# - Service 1 state
|
||||
# - Service 2 state
|
||||
# - ...
|
||||
# - Service N state
|
||||
for node_key, node in self.nodes.items():
|
||||
hardware_state = node.hardware_state.value
|
||||
software_state = 0
|
||||
file_system_state = 0
|
||||
services_states = [0] * self.num_services
|
||||
|
||||
if isinstance(
|
||||
node, ActiveNode
|
||||
): # ServiceNode is a subclass of ActiveNode so no need to check that also
|
||||
software_state = node.software_state.value
|
||||
file_system_state = node.file_system_state_observed.value
|
||||
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.services_list):
|
||||
if node.has_service(service):
|
||||
services_states[i] = node.get_service_state(service).value
|
||||
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
software_state,
|
||||
file_system_state,
|
||||
*services_states,
|
||||
]
|
||||
)
|
||||
|
||||
# 2. Set links
|
||||
# Each link has just one variable in the observation space, it represents the traffic amount
|
||||
# In order for the space to be fully MultiDiscrete, the amount of
|
||||
# traffic on each link is quantised into a few levels:
|
||||
# 0: no traffic (0% of bandwidth)
|
||||
# 1: low traffic (0-33% of bandwidth)
|
||||
# 2: medium traffic (33-66% of bandwidth)
|
||||
# 3: high traffic (66-100% of bandwidth)
|
||||
# 4: max traffic/overloaded (100% of bandwidth)
|
||||
|
||||
for link_key, link in self.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
load = link.get_current_load()
|
||||
|
||||
if load <= 0:
|
||||
traffic_level = 0
|
||||
elif load >= bandwidth:
|
||||
traffic_level = 4
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (1 / 3) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
self.env_obs = np.asarray(obs)
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
if self.observation_type == ObservationType.BOX:
|
||||
self._update_env_obs_box()
|
||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
||||
self._update_env_obs_multidiscrete()
|
||||
self.obs_handler.update_obs()
|
||||
self.env_obs = self.obs_handler.current_observation
|
||||
|
||||
def load_config(self):
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
@@ -921,9 +699,6 @@ class Primaite(Env):
|
||||
elif item["itemType"] == "ACTIONS":
|
||||
# Get the action information
|
||||
self.get_action_info(item)
|
||||
elif item["itemType"] == "OBSERVATIONS":
|
||||
# Get the observation information
|
||||
self.get_observation_info(item)
|
||||
elif item["itemType"] == "STEPS":
|
||||
# Get the steps information
|
||||
self.get_steps_info(item)
|
||||
@@ -1256,13 +1031,17 @@ class Primaite(Env):
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
"""Extracts observation_info.
|
||||
def save_obs_config(self, obs_config: dict):
|
||||
"""Cache the config for the observation space.
|
||||
|
||||
:param observation_info: Config item that defines which type of observation space to use
|
||||
:type observation_info: str
|
||||
This is necessary as the observation space can't be built while reading the config,
|
||||
it must be done after all the nodes, links, and services have been initialised.
|
||||
|
||||
:param obs_config: Parsed config relating to the observation space. The format is described in
|
||||
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
|
||||
:type obs_config: dict
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
self.obs_config = obs_config
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
@@ -1347,3 +1126,91 @@ class Primaite(Env):
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
||||
|
||||
example return:
|
||||
{0: [1, 0, 0, 0],
|
||||
1: [1, 1, 1, 0],
|
||||
2: [1, 1, 2, 0],
|
||||
3: [1, 1, 3, 0],
|
||||
4: [1, 2, 1, 0],
|
||||
5: [1, 3, 1, 0],
|
||||
...
|
||||
}
|
||||
|
||||
"""
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
action_key = 1
|
||||
for node in range(1, self.num_nodes + 1):
|
||||
# 4 node properties (NONE, OPERATING, OS, SERVICE)
|
||||
for node_property in range(4):
|
||||
# Node Actions either:
|
||||
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
|
||||
# Use MAX to ensure we get them all
|
||||
for node_action in range(4):
|
||||
for service_state in range(self.num_services):
|
||||
action = [node, node_property, node_action, service_state]
|
||||
# check to see if it's a nothing action (has no effect)
|
||||
if is_valid_node_action(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
def create_acl_action_dict(self):
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
for action_decision in range(3):
|
||||
# 2 possible action permissions 0 = DENY, 1 = CREATE
|
||||
for action_permission in range(2):
|
||||
# Number of nodes + 1 (for any)
|
||||
for source_ip in range(self.num_nodes + 1):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
The dictionary contains actions of both Node and ACL action types.
|
||||
|
||||
"""
|
||||
node_action_dict = self.create_node_action_dict()
|
||||
acl_action_dict = self.create_acl_action_dict()
|
||||
|
||||
# Change node keys to not overlap with acl keys
|
||||
# Only 1 nothing action (key 0) is required, remove the other
|
||||
new_node_action_dict = {
|
||||
k + len(acl_action_dict) - 1: v
|
||||
for k, v in node_action_dict.items()
|
||||
if k != 0
|
||||
}
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
return combined_action_dict
|
||||
|
||||
@@ -24,6 +24,7 @@ from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
def run_generic():
|
||||
"""Run against a generic agent."""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
env.reset()
|
||||
for step in range(0, config_values.num_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
@@ -41,7 +42,6 @@ def run_generic():
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -162,6 +162,10 @@ def load_config_values():
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
if "observationSpace" in config_data:
|
||||
config_values.observation_config = config_data["observationSpace"]
|
||||
else:
|
||||
config_values.observation_config = None
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = (
|
||||
@@ -376,7 +380,7 @@ logging.info("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(transaction_list)
|
||||
|
||||
config_file_main.close
|
||||
config_file_main.close()
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
|
||||
Reference in New Issue
Block a user