End-of-day commit
This commit is contained in:
276
example_config.yaml
Normal file
276
example_config.yaml
Normal file
@@ -0,0 +1,276 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algo: PPO
|
||||
n_learn_steps: 128
|
||||
n_learn_episodes: 1000
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
team: SCRIPTED_GREEN_<class>
|
||||
observation_space:
|
||||
...
|
||||
action_space:
|
||||
...
|
||||
reward_function:
|
||||
- type: null_reward
|
||||
# node_ref: client_1
|
||||
# service: WebBrowser
|
||||
# pol:
|
||||
# - step: 1
|
||||
# action: START
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: SCRIPTED_RED_<class>
|
||||
observation_space:
|
||||
network:
|
||||
nodes:
|
||||
- ref: client_1
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
- operating_status
|
||||
- health_status
|
||||
folders:
|
||||
files: {}
|
||||
nics: {}
|
||||
|
||||
action_space:
|
||||
actions:
|
||||
- DO_NOTHING
|
||||
network:
|
||||
nodes:
|
||||
- ref: client_1
|
||||
actions:
|
||||
- SCAN
|
||||
- LOGON
|
||||
- LOGOFF
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
actions:
|
||||
- type: COMPROMISE
|
||||
execution_definition:
|
||||
server_ip: 192.168.1.14
|
||||
payload: "DROP TABLE IF EXISTS user;"
|
||||
success_rate: 80%
|
||||
folders:
|
||||
files: {}
|
||||
reward_function: null
|
||||
options: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
|
||||
|
||||
|
||||
- ref: defender
|
||||
team: blue
|
||||
type: GATE_RL_AGENT
|
||||
observation_space:
|
||||
network:
|
||||
nodes:
|
||||
- ref: <noderef>
|
||||
action_space:
|
||||
...
|
||||
reward_function:
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: swtich
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: dns_server
|
||||
options:
|
||||
domain_mapping:
|
||||
- arcd.com: 192.168.1.12 # web server
|
||||
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: database_client
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: database_service
|
||||
|
||||
|
||||
- ref: backup_server
|
||||
type: node
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: database_backup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2:
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21.
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: data_manipulation_bot
|
||||
- ref: client_1_dns_client
|
||||
type: dns_client
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_browser
|
||||
type: web_browser
|
||||
- ref: client_2_dns_client
|
||||
type: dns_client
|
||||
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -11,10 +11,13 @@ from primaite.game.actor.observations import ObservationSpace
|
||||
from primaite.game.actor.rewards import RewardFunction
|
||||
|
||||
|
||||
class AbstractActor(BaseModel):
|
||||
class AbstractActor(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
...
|
||||
def __init__(self) -> None:
|
||||
self.action_space = ActionSpace
|
||||
self.observation_space = ObservationSpace
|
||||
self.reward_function = RewardFunction
|
||||
|
||||
|
||||
class AbstractScriptedActor(AbstractActor):
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Hashable, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, Hashable, List, Optional
|
||||
|
||||
from gym import spaces
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
|
||||
"""
|
||||
Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes
|
||||
the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose.
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
|
||||
@@ -20,19 +27,17 @@ def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
if not keys:
|
||||
if len(keys) == 0:
|
||||
return dictionary
|
||||
k = keys.pop(0)
|
||||
try:
|
||||
return access_from_nested_dict(dictionary[k], keys)
|
||||
except (TypeError, KeyError):
|
||||
raise KeyError(f"Cannot find requested key `{k}` in nested dictionary")
|
||||
if k not in dictionary:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
return access_from_nested_dict(dictionary[k], keys)
|
||||
|
||||
|
||||
class AbstractObservation(BaseModel):
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, state: Dict) -> Any:
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""_summary_
|
||||
|
||||
:param state: _description_
|
||||
@@ -41,7 +46,6 @@ class AbstractObservation(BaseModel):
|
||||
:rtype: Any
|
||||
"""
|
||||
...
|
||||
# receive state dict
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -51,72 +55,396 @@ class AbstractObservation(BaseModel):
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
where: List[str]
|
||||
"""Store information about where in the simulation state dictionary to find the relevatn information."""
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
"""
|
||||
_summary_
|
||||
|
||||
def __call__(self, state: Dict) -> Dict:
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
observation = {'health_status':file_state['health_status']}
|
||||
return observation
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["health_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({'health_status':spaces.Discrete(6)})
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": service_state["operating_status"], "health_status": service_state["health_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 10) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"protocols": {"all": {"load": utilisation_category}}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})})
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
def __init__(self, where: List[str] = [], files: List[FileObservation] = []) -> None:
|
||||
"""Initialise folder Observation, including files inside of the folder.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: List[str] = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"nic_status": 0}
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
super.__init__()
|
||||
self.where: List[str] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
return {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
def __init__(
|
||||
self,
|
||||
where: List[str] = [],
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
nics: List[NicObservation] = [],
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service UUID, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
|
||||
:type nics: Dict[int,str], optional
|
||||
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super.__init__()
|
||||
self.where: List[str] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
self.folders: List[FolderObservation] = folders
|
||||
self.nics: List[NicObservation] = nics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
|
||||
"operating_status": 0,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict(
|
||||
{
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(0),
|
||||
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self, nodes: List[str], ports: List[int], protocols: list[str], where: List[str] = [], num_rules: int = 10
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = {node: i + 1 for i, node in enumerate(nodes)}
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 1 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 1 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: spaces.Space = spaces.Dict(
|
||||
{
|
||||
"RULES": spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["RULES"] = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
if rule_state is None:
|
||||
obs["RULES"][i + 1] = {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs["RULES"][i + 1] = {
|
||||
"position": i,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict(
|
||||
{
|
||||
"RULE": spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
"source_node_id": spaces.Discrete(len(self.nodes) + 1),
|
||||
"source_port": spaces.Discrete(len(self.ports) + 1),
|
||||
"dest_node_id": spaces.Discrete(len(self.nodes) + 1),
|
||||
"dest_port": spaces.Discrete(len(self.ports) + 1),
|
||||
"protocol": spaces.Discrete(len(self.protocols) + 1),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ICSObservation(AbstractObservation):
|
||||
def observe(self, state: Dict) -> Any:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Discrete(1)
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
"""Manage the observations of an Actor."""
|
||||
"""
|
||||
Manage the observations of an Actor.
|
||||
|
||||
The observation space has the purpose of:
|
||||
1. Reading the outputted state from the PrimAITE Simulation.
|
||||
2. Selecting parts of the simulation state that are requested by the simulation config
|
||||
3. Formatting this information so an actor can use it to make decisions.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
# what this class does:
|
||||
# keep a list of observations
|
||||
# create observations for an actor from the config
|
||||
def __init__(
|
||||
self,
|
||||
simulation: Simulation,
|
||||
nodes: List[NodeObservation] = [],
|
||||
links: List[LinkObservation] = [],
|
||||
acl: Optional[AclObservation] = None,
|
||||
ics: Optional[ICSObservation] = None,
|
||||
) -> None:
|
||||
self.simulation: Simulation = simulation
|
||||
self.parts: Dict[str, AbstractObservation] = {}
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: Optional[AclObservation] = acl
|
||||
self.ics: Optional[ICSObservation] = ics
|
||||
|
||||
# Example YAML file for agent observation space
|
||||
"""
|
||||
arcd_gate:
|
||||
rl_framework: SB3
|
||||
rl_algo: PPO
|
||||
n_learn_steps: 128
|
||||
n_learn_episodes: 1000
|
||||
def observe(self) -> None:
|
||||
...
|
||||
|
||||
game_layer:
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
type: GREEN
|
||||
node_ref: client_1
|
||||
service: WebBrowser
|
||||
pol:
|
||||
- step: 1
|
||||
action: START
|
||||
@property
|
||||
def space(self) -> None:
|
||||
...
|
||||
|
||||
- ref: client_1_data_manip_red_bot
|
||||
node_ref: client_1
|
||||
service: DataManipulationBot
|
||||
execution_definition:
|
||||
- server_ip_address: 192.168.1.10
|
||||
- server_password:
|
||||
- payload: 'ATTACK'
|
||||
|
||||
pol:
|
||||
- step: 75
|
||||
action: EXECUTE
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
nodes:
|
||||
- ref: client_1
|
||||
hostname: client_1
|
||||
node_type: Computer
|
||||
ip_address: 192.168.10.100
|
||||
services:
|
||||
- name: DataManipulationBot
|
||||
links:
|
||||
endpoint_a:
|
||||
endpoint_b: 1524552-fgfg4147gdh-25gh4gd
|
||||
rewards:
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
def from_config(self) -> None:
|
||||
...
|
||||
|
||||
@@ -855,14 +855,14 @@ class ICMP:
|
||||
class NodeOperatingState(Enum):
|
||||
"""Enumeration of Node Operating States."""
|
||||
|
||||
OFF = 0
|
||||
"The node is powered off."
|
||||
ON = 1
|
||||
"The node is powered on."
|
||||
SHUTTING_DOWN = 2
|
||||
"The node is in the process of shutting down."
|
||||
OFF = 2
|
||||
"The node is powered off."
|
||||
BOOTING = 3
|
||||
"The node is in the process of booting up."
|
||||
SHUTTING_DOWN = 4
|
||||
"The node is in the process of shutting down."
|
||||
|
||||
|
||||
class Node(SimComponent):
|
||||
|
||||
@@ -58,7 +58,14 @@ class ACLRule(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value
|
||||
state["src_ip_address"] = self.src_ip_address
|
||||
state["src_port"] = self.src_port.value
|
||||
state["dst_ip_address"] = self.dst_ip_address
|
||||
state["dst_port"] = self.dst_port.value
|
||||
return state
|
||||
|
||||
|
||||
class AccessControlList(SimComponent):
|
||||
@@ -123,7 +130,12 @@ class AccessControlList(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["implicit_action"] = self.implicit_action.value
|
||||
state["implicit_rule"] = self.implicit_rule.describe_state()
|
||||
state["max_acl_rules"] = self.max_acl_rules
|
||||
state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)}
|
||||
return state
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Optional[ACLRule]]:
|
||||
@@ -648,7 +660,10 @@ class Router(Node):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
|
||||
@@ -15,14 +15,14 @@ class ServiceOperatingState(Enum):
|
||||
"The service is currently running."
|
||||
STOPPED = 2
|
||||
"The service is not running."
|
||||
INSTALLING = 3
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 4
|
||||
"The service is in the process of restarting."
|
||||
PAUSED = 5
|
||||
PAUSED = 3
|
||||
"The service is temporarily paused."
|
||||
DISABLED = 6
|
||||
DISABLED = 4
|
||||
"The service is disabled and cannot be started."
|
||||
INSTALLING = 5
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 6
|
||||
"The service is in the process of restarting."
|
||||
|
||||
|
||||
class Service(IOSoftware):
|
||||
@@ -60,7 +60,7 @@ class Service(IOSoftware):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state.update({"operating_state": self.operating_state.name})
|
||||
state.update({"operating_state": self.operating_state.value})
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
|
||||
20
tests/integration_tests/game_layer/test_observations.py
Normal file
20
tests/integration_tests/game_layer/test_observations.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from gym import spaces
|
||||
|
||||
from primaite.game.actor.observations import FileObservation
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
def test_file_observation():
|
||||
sim = Simulation()
|
||||
pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0")
|
||||
sim.network.add_node(pc)
|
||||
f = pc.file_system.create_file(file_name="dog.png")
|
||||
|
||||
state = sim.describe_state()
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
|
||||
)
|
||||
assert dog_file_obs(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
Reference in New Issue
Block a user