901 - started testing for observation space
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
@@ -9,6 +10,7 @@ from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import ActionType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
@@ -331,7 +331,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
"""
|
||||
|
||||
# Terms (for ACL observation space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
@@ -352,7 +351,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
len(env.services_list),
|
||||
len(env.ports_list),
|
||||
]
|
||||
shape = acl_shape * self.env.max_acl_rules
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
@@ -360,9 +359,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
# Dictionary to map services to numbers for obs space
|
||||
self.services_dict = {}
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -380,7 +376,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
permission_int = 0
|
||||
else:
|
||||
permission_int = 1
|
||||
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 0
|
||||
else:
|
||||
@@ -393,9 +388,10 @@ class AccessControlList(AbstractObservationComponent):
|
||||
protocol_int = 0
|
||||
else:
|
||||
try:
|
||||
protocol_int = Protocol[protocol]
|
||||
protocol_int = Protocol[protocol].value
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = -1
|
||||
if port == "ANY":
|
||||
port_int = 0
|
||||
else:
|
||||
@@ -423,7 +419,8 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space
|
||||
"""
|
||||
for key, node in self.env.nodes:
|
||||
print(type(self.env.nodes))
|
||||
for key, node in self.env.nodes.items():
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if node.ip_address == ip_address:
|
||||
return key
|
||||
|
||||
@@ -122,6 +122,9 @@ class Primaite(Env):
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user