Make game layer work with new state api

This commit is contained in:
Marek Wolan
2023-12-14 14:04:43 +00:00
parent 1ec7df1170
commit 6a80f4cc77
10 changed files with 145 additions and 196 deletions

1
.gitignore vendored
View File

@@ -156,3 +156,4 @@ benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py
sandbox.ipynb

View File

@@ -105,25 +105,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +138,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -509,7 +509,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -517,8 +517,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -40,8 +40,7 @@ class AbstractObservation(ABC):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
@@ -53,12 +52,12 @@ class FileObservation(AbstractObservation):
"""
Initialise file observation.
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
@@ -120,7 +119,7 @@ class ServiceObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -162,7 +161,7 @@ class ServiceObservation(AbstractObservation):
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
return cls(where=parent_where + ["services", config["service_name"]])
class LinkObservation(AbstractObservation):
@@ -179,7 +178,7 @@ class LinkObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -242,7 +241,7 @@ class FolderObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
@@ -321,7 +320,7 @@ class FolderObservation(AbstractObservation):
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
@@ -347,7 +346,7 @@ class NicObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
['network','nodes',<node_hostname>,'NICs',<nic_index>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
@@ -384,12 +383,12 @@ class NicObservation(AbstractObservation):
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
return cls(where=parent_where + ["NICs", config["nic_num"]])
class NodeObservation(AbstractObservation):
@@ -412,9 +411,9 @@ class NodeObservation(AbstractObservation):
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service UUID, defaults to {}
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
@@ -423,7 +422,7 @@ class NodeObservation(AbstractObservation):
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
:param nics: Mapping between position in observation space and NIC idx, defaults to {}
:type nics: Dict[int,str], optional
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
@@ -541,11 +540,11 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = game.ref_map_nodes[config["node_ref"]]
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_uuid]
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_uuid]
where = parent_where + ["nodes", node_hostname]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
@@ -556,8 +555,8 @@ class NodeObservation(AbstractObservation):
)
for c in folder_configs
]
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
@@ -598,7 +597,7 @@ class AclObservation(AbstractObservation):
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_uuid>,'acl','acl']
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
@@ -711,12 +710,12 @@ class AclObservation(AbstractObservation):
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -846,6 +845,7 @@ class UC2BlueObservation(AbstractObservation):
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]

View File

@@ -82,11 +82,11 @@ class DummyReward(AbstractReward):
class DatabaseFileIntegrity(AbstractReward):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the database file.
:type node_uuid: str
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
@@ -95,7 +95,7 @@ class DatabaseFileIntegrity(AbstractReward):
self.location_in_state = [
"network",
"nodes",
node_uuid,
node_hostname,
"file_system",
"folders",
folder_name,
@@ -129,49 +129,29 @@ class DatabaseFileIntegrity(AbstractReward):
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_ref = config.get("node_ref")
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not node_ref:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not folder_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not file_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
f"found in the simulation"
)
)
return DummyReward() # TODO: better error handling
if not node_hostname and folder_name and file_name:
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_uuid: str, service_uuid: str) -> None:
def __init__(self, node_hostname: str, service_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the web server service.
:type node_uuid: str
:param service_uuid: UUID of the web server service.
:type service_uuid: str
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_node: Name of the web server service.
:type service_node: str
"""
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state.
@@ -203,26 +183,17 @@ class WebServer404Penalty(AbstractReward):
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_ref = config.get("node_ref")
service_ref = config.get("service_ref")
if not (node_ref and service_ref):
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref]
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: consider erroring here as well
raise ValueError(msg)
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
return cls(node_hostname=node_hostname, service_name=service_name)
class RewardFunction:

View File

@@ -93,25 +93,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -126,7 +126,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -497,7 +497,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -505,8 +505,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -31,13 +31,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -104,25 +97,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,7 +130,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -508,7 +501,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -516,8 +509,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -37,13 +37,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -111,25 +104,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -144,7 +137,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -515,7 +508,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -523,8 +516,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
@@ -542,25 +535,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -575,7 +568,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -946,7 +939,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -954,8 +947,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -35,13 +35,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -109,25 +102,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
# - service_name: backup_service
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -142,7 +135,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -513,7 +506,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -521,8 +514,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -105,25 +105,23 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +136,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -509,7 +507,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -517,8 +515,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_service
agent_settings:

View File

@@ -14,7 +14,7 @@ def test_file_observation():
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})