Merge branch 'dev' into bugfix/2442-add_SubprocVecEnv_support

This commit is contained in:
Nick Todd
2024-05-02 16:59:08 +01:00
47 changed files with 3006 additions and 283 deletions

View File

@@ -122,35 +122,20 @@ class _PrimaitePaths:
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
def _host_primaite_config() -> None:
if not PRIMAITE_PATHS.app_config_file_path.exists():
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_host_primaite_config()
def _get_primaite_config() -> Dict:
config_path = PRIMAITE_PATHS.app_config_file_path
if not config_path.exists():
# load from package if config does not exist
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
# generate app config
shutil.copy2(config_path, PRIMAITE_PATHS.app_config_file_path)
with open(config_path, "r") as file:
# load from config
primaite_config = yaml.safe_load(file)
log_level_map = {
"NOTSET": logging.NOTSET,
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
return primaite_config
_PRIMAITE_CONFIG = _get_primaite_config()
PRIMAITE_CONFIG = _get_primaite_config()
class _LevelFormatter(Formatter):
@@ -177,11 +162,11 @@ class _LevelFormatter(Formatter):
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
logging.DEBUG: PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
}
)
@@ -193,10 +178,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
backupCount=9, # Max 100MB of logs
encoding="utf8",
)
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_STREAM_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
_LOG_FORMAT_STR: Final[str] = PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
@@ -215,6 +200,6 @@ def getLogger(name: str) -> Logger: # noqa
logging config.
"""
logger = logging.getLogger(name)
logger.setLevel(_PRIMAITE_CONFIG["log_level"])
logger.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
return logger

View File

@@ -2,16 +2,21 @@
"""Provides a CLI using Typer as an entry point."""
import logging
import os
import shutil
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
import typer
import yaml
from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS
from primaite.utils.cli import dev_cli
app = typer.Typer(no_args_is_help=True)
app.add_typer(dev_cli.dev, name="dev-mode")
@app.command()
@@ -89,7 +94,7 @@ def version() -> None:
@app.command()
def setup(overwrite_existing: bool = True) -> None:
def setup(overwrite_existing: bool = False) -> None:
"""
Perform the PrimAITE first-time setup.
@@ -102,11 +107,14 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Performing the PrimAITE first-time setup...")
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...")
PRIMAITE_PATHS.mkdirs()
_LOGGER.info("Building primaite_config.yaml...")
if overwrite_existing:
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=True)
@@ -114,47 +122,3 @@ def setup(overwrite_existing: bool = True) -> None:
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def mode(
dev: Annotated[bool, typer.Option("--dev", help="Activates PrimAITE developer mode")] = None,
prod: Annotated[bool, typer.Option("--prod", help="Activates PrimAITE production mode")] = None,
) -> None:
"""
Switch PrimAITE between developer mode and production mode.
By default, PrimAITE will be in production mode.
To view the current mode, use: primaite mode
To set to development mode, use: primaite mode --dev
To return to production mode, use: primaite mode --prod
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if dev and prod:
print("Unable to activate developer and production modes concurrently.")
return
if (dev is None) and (prod is None):
is_dev_mode = primaite_config["developer_mode"]
if is_dev_mode:
print("PrimAITE is running in developer mode.")
else:
print("PrimAITE is running in production mode.")
if dev:
# activate dev mode
primaite_config["developer_mode"] = True
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print("PrimAITE is running in developer mode.")
if prod:
# activate prod mode
primaite_config["developer_mode"] = False
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print("PrimAITE is running in production mode.")

View File

@@ -0,0 +1,65 @@
game:
ports:
- ARP
protocols:
- ICMP
- TCP
- UDP
simulation:
network:
nodes:
- hostname: pc_1
type: computer
ip_address: 192.168.1.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: pc_2
type: computer
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: server_1
type: server
ip_address: 192.168.1.13
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: switch_1
type: switch
num_ports: 4
- hostname: router_1
type: router
num_ports: 1
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
10:
action: PERMIT
src_ip: 192.168.1.0
src_wildcard_mask: 0.0.0.255
dst_ip: 192.168.1.1
dst_wildcard_mask: 0.0.0.0
links:
- endpoint_a_hostname: pc_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: pc_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 3
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 4

View File

@@ -0,0 +1,26 @@
game:
ports:
- ARP
protocols:
- ICMP
- TCP
- UDP
simulation:
network:
nodes:
- hostname: pc_1
type: computer
ip_address: 192.168.1.11
subnet_mask: 255.255.255.0
- hostname: server_1
type: server
ip_address: 192.168.1.13
subnet_mask: 255.255.255.0
links:
- endpoint_a_hostname: pc_1
endpoint_a_port: 1
endpoint_b_hostname: server_1
endpoint_b_port: 1

View File

@@ -0,0 +1,439 @@
game:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
simulation:
network:
nodes:
# Home/Office Network
- hostname: pc_1
type: computer
ip_address: 192.168.1.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 8.8.8.2
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
- hostname: pc_2
type: computer
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 8.8.8.2
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
- hostname: server_1
type: server
ip_address: 192.168.1.13
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 8.8.8.2
- hostname: switch_1
type: switch
num_ports: 4
- hostname: router_1
type: router
num_ports: 2
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 43.35.240.2
subnet_mask: 255.255.255.252
acl:
10:
action: PERMIT
default_route: # Default route to all external networks
next_hop_ip_address: 43.35.240.1 # NI 1 on icp_router
# ISP Network
- hostname: isp_rt
type: router
num_ports: 3
ports:
1:
ip_address: 43.35.240.1
subnet_mask: 255.255.255.252
2:
ip_address: 94.10.180.1
subnet_mask: 255.255.255.252
3:
ip_address: 8.8.8.1
subnet_mask: 255.255.255.252
acl:
10:
action: PERMIT
routes:
- address: 192.168.1.0 # Route to the Home/Office LAN
subnet_mask: 255.255.255.0
next_hop_ip_address: 43.35.240.2 # NI 2 on router_1
- address: 10.10.0.0 # Route to the SomeTech internal network
subnet_mask: 255.255.0.0
next_hop_ip_address: 94.10.180.2 # NI ext on some_tech_fw
- address: 94.10.180.6 # Route to the Web Server in the SomeTech DMZ
subnet_mask: 255.255.255.255
next_hop_ip_address: 94.10.180.2 # NI ext on some_tech_fw
- hostname: isp_dns_srv
type: server
ip_address: 8.8.8.2
subnet_mask: 255.255.255.252
default_gateway: 8.8.8.1
services:
- ref: dns_server
type: DNSServer
options:
domain_mapping:
sometech.ai: 94.10.180.6
# SomeTech Network
- hostname: some_tech_fw
type: firewall
ports:
external_port: # port 1
ip_address: 94.10.180.2
subnet_mask: 255.255.255.252
internal_port: # port 2
ip_address: 10.10.4.2
subnet_mask: 255.255.255.252
dmz_port: # port 3
ip_address: 94.10.180.5
subnet_mask: 255.255.255.252
acl:
internal_inbound_acl:
8: # Permit some_tech_web_srv to connect to Database service on some_tech_db_srv
action: PERMIT
src_ip: 94.10.180.6
src_wildcard_mask: 0.0.0.0
src_port: POSTGRES_SERVER
dst_ip: 10.10.1.11
dst_wildcard_mask: 0.0.0.0
dst_port: POSTGRES_SERVER
9: # Permit SomeTech to use HTTP
action: PERMIT
src_port: HTTP
10: # Permit SomeTech to use DNS
action: PERMIT
src_port: DNS
dst_port: DNS
internal_outbound_acl:
10: # Permit all internal outbound traffic
action: PERMIT
dmz_inbound_acl:
7: # Permit Database service on some_tech_db_srv to respond to some_tech_web_srv
action: PERMIT
src_ip: 10.10.1.11
src_port: POSTGRES_SERVER
src_wildcard_mask: 0.0.0.0
dst_ip: 94.10.180.6
dst_port: POSTGRES_SERVER
dst_wildcard_mask: 0.0.0.0
8: # Permit SomeTech DMZ to use ARP
action: PERMIT
src_port: ARP
dst_port: ARP
9: # Permit SomeTech DMZ to use DNS
action: PERMIT
src_port: DNS
dst_port: DNS
10: # Permit all inbound HTTP requests
action: PERMIT
dst_port: HTTP
dmz_outbound_acl:
7: # Permit some_tech_web_srv to connect to Database service on some_tech_db_srv
action: PERMIT
src_ip: 94.10.180.6
src_port: POSTGRES_SERVER
src_wildcard_mask: 0.0.0.0
dst_ip: 10.10.1.11
dst_port: POSTGRES_SERVER
dst_wildcard_mask: 0.0.0.0
8: # Permit SomeTech DMZ to use ARP
action: PERMIT
src_port: ARP
dst_port: ARP
9: # Permit SomeTech DMZ to use DNS
action: PERMIT
src_port: DNS
dst_port: DNS
10: # Permit all outbound HTTP requests
action: PERMIT
src_port: HTTP
default_route: # Default route to all external networks
next_hop_ip_address: 94.10.180.1 # NI 2 on isp_rt
routes:
- address: 10.10.0.0 # Route to the SomeTech internal LAN
subnet_mask: 255.255.0.0
next_hop_ip_address: 10.10.4.1 # NI 1 on some_tech_rt
- hostname: some_tech_web_srv
type: server
ip_address: 94.10.180.6
subnet_mask: 255.255.255.252
default_gateway: 94.10.180.5
dns_server: 8.8.8.2
services:
- ref: web_server
type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- hostname: some_tech_rt
type: router
num_ports: 4
ports:
1:
ip_address: 10.10.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 10.10.4.1
subnet_mask: 255.255.255.252
3:
ip_address: 10.10.3.1
subnet_mask: 255.255.255.0
4:
ip_address: 10.10.2.1
subnet_mask: 255.255.255.0
acl:
2: # Allow the some_tech_web_srv to connect to the Database Service on some_tech_db_srv
action: PERMIT
src_ip: 94.10.180.6
src_wildcard_mask: 0.0.0.0
src_port: POSTGRES_SERVER
dst_ip: 10.10.1.11
dst_wildcard_mask: 0.0.0.0
dst_port: POSTGRES_SERVER
3: # Allow the Database Service on some_tech_db_srv to respond to some_tech_web_srv
action: PERMIT
src_ip: 10.10.1.11
src_wildcard_mask: 0.0.0.0
src_port: POSTGRES_SERVER
dst_ip: 94.10.180.6
dst_wildcard_mask: 0.0.0.0
dst_port: POSTGRES_SERVER
4: # Prevent the Junior engineer from downloading files from the some_tech_storage_srv over FTP
action: DENY
src_ip: 10.10.2.12
src_wildcard_mask: 0.0.0.0
src_port: FTP
dst_ip: 10.10.1.12
dst_wildcard_mask: 0.0.0.0
dst_port: FTP
5: # Allow communication between Engineering and the DB & Storage subnet
action: PERMIT
src_ip: 10.10.2.0
src_wildcard_mask: 0.0.0.255
dst_ip: 10.10.1.0
dst_wildcard_mask: 0.0.0.255
6: # Allow communication between the DB & Storage subnet and Engineering
action: PERMIT
src_ip: 10.10.1.0
src_wildcard_mask: 0.0.0.255
dst_ip: 10.10.2.0
dst_wildcard_mask: 0.0.0.255
7: # Allow the SomeTech network to use HTTP
action: PERMIT
src_port: HTTP
dst_port: HTTP
8: # Allow the SomeTech internal network to use ARP
action: PERMIT
src_ip: 10.10.0.0
src_wildcard_mask: 0.0.255.255
src_port: ARP
9: # Allow the SomeTech internal network to use ICMP
action: PERMIT
src_ip: 10.10.0.0
src_wildcard_mask: 0.0.255.255
protocol: ICMP
10:
action: PERMIT
src_ip: 94.10.180.6
src_wildcard_mask: 0.0.0.0
src_port: HTTP
dst_ip: 10.10.0.0
dst_wildcard_mask: 0.0.255.255
dst_port: HTTP
11: # Permit SomeTech to use DNS
action: PERMIT
src_port: DNS
dst_port: DNS
default_route: # Default route to all external networks
next_hop_ip_address: 10.10.4.2 # NI int on some_tech_fw
- hostname: some_tech_data_sw
type: switch
num_ports: 3
- hostname: some_tech_hr_sw
type: switch
num_ports: 2
- hostname: some_tech_eng_sw
type: switch
num_ports: 3
- hostname: some_tech_db_srv
type: server
ip_address: 10.10.1.11
subnet_mask: 255.255.255.0
default_gateway: 10.10.1.1
dns_server: 8.8.8.2
services:
- type: DatabaseService
options:
backup_server_ip: 10.10.1.12 # The some_tech_storage_srv server
- type: FTPClient
- hostname: some_tech_storage_srv
type: server
ip_address: 10.10.1.12
subnet_mask: 255.255.255.0
default_gateway: 10.10.1.1
dns_server: 8.8.8.2
services:
- type: FTPServer
- hostname: some_tech_hr_1
type: computer
ip_address: 10.10.3.11
subnet_mask: 255.255.255.0
default_gateway: 10.10.3.1
dns_server: 8.8.8.2
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
- hostname: some_tech_snr_dev_pc
type: computer
ip_address: 10.10.2.11
subnet_mask: 255.255.255.0
default_gateway: 10.10.2.1
dns_server: 8.8.8.2
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
- hostname: some_tech_jnr_dev_pc
type: computer
ip_address: 10.10.2.12
subnet_mask: 255.255.255.0
default_gateway: 10.10.2.1
dns_server: 8.8.8.2
applications:
- type: DatabaseClient
options:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
links:
# Home/Office Lan Links
- endpoint_a_hostname: pc_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: pc_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 3
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 4
# ISP Links
- endpoint_a_hostname: isp_rt
endpoint_a_port: 1
endpoint_b_hostname: router_1
endpoint_b_port: 2
- endpoint_a_hostname: isp_rt
endpoint_a_port: 2
endpoint_b_hostname: some_tech_fw
endpoint_b_port: 1
- endpoint_a_hostname: isp_rt
endpoint_a_port: 3
endpoint_b_hostname: isp_dns_srv
endpoint_b_port: 1
# SomeTech LAN Links
- endpoint_a_hostname: some_tech_fw
endpoint_a_port: 3
endpoint_b_hostname: some_tech_web_srv
endpoint_b_port: 1
- endpoint_a_hostname: some_tech_fw
endpoint_a_port: 2
endpoint_b_hostname: some_tech_rt
endpoint_b_port: 2
- endpoint_a_hostname: some_tech_rt
endpoint_a_port: 1
endpoint_b_hostname: some_tech_data_sw
endpoint_b_port: 3
- endpoint_a_hostname: some_tech_rt
endpoint_a_port: 3
endpoint_b_hostname: some_tech_hr_sw
endpoint_b_port: 2
- endpoint_a_hostname: some_tech_rt
endpoint_a_port: 4
endpoint_b_hostname: some_tech_eng_sw
endpoint_b_port: 3
- endpoint_a_hostname: some_tech_data_sw
endpoint_a_port: 1
endpoint_b_hostname: some_tech_db_srv
endpoint_b_port: 1
- endpoint_a_hostname: some_tech_data_sw
endpoint_a_port: 2
endpoint_b_hostname: some_tech_storage_srv
endpoint_b_port: 1
- endpoint_a_hostname: some_tech_hr_sw
endpoint_a_port: 1
endpoint_b_hostname: some_tech_hr_1
endpoint_b_port: 1
- endpoint_a_hostname: some_tech_eng_sw
endpoint_a_port: 1
endpoint_b_hostname: some_tech_snr_dev_pc
endpoint_b_port: 1
- endpoint_a_hostname: some_tech_eng_sw
endpoint_a_port: 2
endpoint_b_hostname: some_tech_jnr_dev_pc
endpoint_b_port: 1

View File

@@ -14,7 +14,6 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.airspace import AIR_SPACE
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -222,7 +221,6 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
AIR_SPACE.clear()
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
@@ -244,7 +242,7 @@ class PrimaiteGame:
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg["default_gateway"],
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
@@ -255,7 +253,7 @@ class PrimaiteGame:
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg["default_gateway"],
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
@@ -274,7 +272,7 @@ class PrimaiteGame:
elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg)
new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
elif n_type == "printer":
new_node = Printer(
hostname=node_cfg["hostname"],

View File

@@ -404,7 +404,7 @@
" # don't flatten observations so that we can see what is going on\n",
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
"env = PrimaiteGymEnv(env_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"

View File

@@ -59,7 +59,7 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game_config=cfg)"
"gym = PrimaiteGymEnv(env_config=cfg)"
]
},
{

View File

@@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.pre_timestep()
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -58,6 +57,7 @@ class PrimaiteGymEnv(gymnasium.Env):
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
reward = self.agent.reward_function.current_reward
_LOGGER.info(f"step: {self.game.step_counter}, Blue reward: {reward}")
terminated = False
truncated = self.game.calculate_truncated()
info = {
@@ -204,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
@@ -244,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}

View File

@@ -5,9 +5,9 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from primaite import getLogger, PRIMAITE_PATHS
from primaite import _PRIMAITE_ROOT, getLogger, PRIMAITE_CONFIG, PRIMAITE_PATHS
from primaite.simulator import LogLevel, SIM_OUTPUT
from src.primaite.utils.primaite_config_utils import is_dev_mode
from primaite.utils.cli.primaite_config_utils import is_dev_mode
_LOGGER = getLogger(__name__)
@@ -62,12 +62,15 @@ class PrimaiteIO:
date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
# check if running in dev mode
if is_dev_mode():
# if dev mode, simulation output will be the current working directory
session_path = Path.cwd() / "simulation_output" / date_str / time_str
else:
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str
# check if there is an output directory set in config
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
session_path = Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) / "sessions" / date_str / time_str
session_path.mkdir(exist_ok=True, parents=True)
return session_path

View File

@@ -1,6 +1,12 @@
# The main PrimAITE application config file
developer_mode: False # false by default
developer_mode:
enabled: False # not enabled by default
sys_log_level: DEBUG # level of output for system logs, DEBUG by default
output_sys_logs: False # system logs not output by default
output_pcap_logs: False # pcap logs not output by default
output_to_terminal: False # do not output to terminal by default
output_dir: null # none by default - none will print to repository root
# Logging
logging:

View File

@@ -3,10 +3,12 @@ from datetime import datetime
from enum import IntEnum
from pathlib import Path
from primaite import _PRIMAITE_ROOT
from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG, PRIMAITE_PATHS
__all__ = ["SIM_OUTPUT"]
from primaite.utils.cli.primaite_config_utils import is_dev_mode
class LogLevel(IntEnum):
"""Enum containing all the available log levels for PrimAITE simulation output."""
@@ -25,16 +27,34 @@ class LogLevel(IntEnum):
class _SimOutput:
def __init__(self):
self._path: Path = (
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
self.save_pcap_logs: bool = False
self.save_sys_logs: bool = False
self.write_sys_log_to_terminal: bool = False
self.sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
date_str = datetime.now().strftime("%Y-%m-%d")
time_str = datetime.now().strftime("%H-%M-%S")
path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
self._path = path
self._save_pcap_logs: bool = False
self._save_sys_logs: bool = False
self._write_sys_log_to_terminal: bool = False
self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
@property
def path(self) -> Path:
if is_dev_mode():
date_str = datetime.now().strftime("%Y-%m-%d")
time_str = datetime.now().strftime("%H-%M-%S")
# if dev mode is enabled, if output dir is not set, print to primaite repo root
path: Path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str / "simulation_output"
# otherwise print to output dir
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
path: Path = (
Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
/ "sessions"
/ date_str
/ time_str
/ "simulation_output"
)
self._path = path
return self._path
@path.setter
@@ -42,5 +62,45 @@ class _SimOutput:
self._path = new_path
self._path.mkdir(exist_ok=True, parents=True)
@property
def save_pcap_logs(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_pcap_logs")
return self._save_pcap_logs
@save_pcap_logs.setter
def save_pcap_logs(self, save_pcap_logs: bool) -> None:
self._save_pcap_logs = save_pcap_logs
@property
def save_sys_logs(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_sys_logs")
return self._save_sys_logs
@save_sys_logs.setter
def save_sys_logs(self, save_sys_logs: bool) -> None:
self._save_sys_logs = save_sys_logs
@property
def write_sys_log_to_terminal(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal")
return self._write_sys_log_to_terminal
@write_sys_log_to_terminal.setter
def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None:
self._write_sys_log_to_terminal = write_sys_log_to_terminal
@property
def sys_log_level(self) -> LogLevel:
if is_dev_mode():
return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("sys_log_level")]
return self._sys_log_level
@sys_log_level.setter
def sys_log_level(self, sys_log_level: LogLevel) -> None:
self._sys_log_level = sys_log_level
SIM_OUTPUT = _SimOutput()

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Final, List, Optional
from typing import Any, Dict, List, Optional
from prettytable import PrettyTable
@@ -14,7 +14,7 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
__all__ = ["AIR_SPACE", "AirSpaceFrequency", "WirelessNetworkInterface", "IPWirelessNetworkInterface"]
__all__ = ["AirSpaceFrequency", "WirelessNetworkInterface", "IPWirelessNetworkInterface"]
class AirSpace:
@@ -100,18 +100,6 @@ class AirSpace:
wireless_interface.receive_frame(frame)
AIR_SPACE: Final[AirSpace] = AirSpace()
"""
A singleton instance of the AirSpace class, representing the global wireless airspace.
This instance acts as the central management point for all wireless communications within the simulated network
environment. By default, there is only one airspace in the simulation, making this variable a singleton that
manages the registration, removal, and transmission of wireless frames across all wireless network interfaces configured
in the simulation. It ensures that wireless frames are appropriately transmitted to and received by wireless
interfaces based on their operational status and frequency band.
"""
class AirSpaceFrequency(Enum):
"""Enumeration representing the operating frequencies for wireless communications."""
@@ -149,6 +137,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
and may define additional properties and methods specific to wireless technology.
"""
airspace: AirSpace
frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4
def enable(self):
@@ -171,7 +160,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
)
AIR_SPACE.add_wireless_interface(self)
self.airspace.add_wireless_interface(self)
def disable(self):
"""Disable the network interface."""
@@ -182,7 +171,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
self._connected_node.sys_log.info(f"Network Interface {self} disabled")
else:
_LOGGER.debug(f"Interface {self} disabled")
AIR_SPACE.remove_wireless_interface(self)
self.airspace.remove_wireless_interface(self)
def send_frame(self, frame: Frame) -> bool:
"""
@@ -198,7 +187,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
AIR_SPACE.transmit(frame, self)
self.airspace.transmit(frame, self)
return True
# Cannot send Frame as the network interface is not enabled
return False

View File

@@ -5,9 +5,11 @@ import matplotlib.pyplot as plt
import networkx as nx
from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.airspace import AirSpace
from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.system.applications.application import Application
@@ -28,7 +30,9 @@ class Network(SimComponent):
"""
nodes: Dict[str, Node] = {}
links: Dict[str, Link] = {}
airspace: AirSpace = Field(default_factory=lambda: AirSpace())
_node_id_map: Dict[int, Node] = {}
_link_id_map: Dict[int, Node] = {}

View File

@@ -330,7 +330,7 @@ class Firewall(Router):
# check if External Inbound ACL Rules permit frame
permitted, rule = self.external_inbound_acl.is_permitted(frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at external inbound by rule {rule}")
return
self.software_manager.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address,
@@ -360,7 +360,7 @@ class Firewall(Router):
# check if External Outbound ACL Rules permit frame
permitted, rule = self.external_outbound_acl.is_permitted(frame=frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at external outbound by rule {rule}")
return
self.process_frame(frame=frame, from_network_interface=from_network_interface)
@@ -380,7 +380,7 @@ class Firewall(Router):
# check if Internal Inbound ACL Rules permit frame
permitted, rule = self.internal_inbound_acl.is_permitted(frame=frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at internal inbound by rule {rule}")
return
self.process_frame(frame=frame, from_network_interface=from_network_interface)
@@ -398,7 +398,7 @@ class Firewall(Router):
"""
permitted, rule = self.internal_outbound_acl.is_permitted(frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at internal outbound by rule {rule}")
return
self.software_manager.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address,
@@ -432,7 +432,7 @@ class Firewall(Router):
# check if DMZ Inbound ACL Rules permit frame
permitted, rule = self.dmz_inbound_acl.is_permitted(frame=frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at DMZ inbound by rule {rule}")
return
self.process_frame(frame=frame, from_network_interface=from_network_interface)
@@ -452,7 +452,7 @@ class Firewall(Router):
"""
permitted, rule = self.dmz_outbound_acl.is_permitted(frame)
if not permitted:
self.sys_log.info(f"Frame blocked at interface {from_network_interface} by rule {rule}")
self.sys_log.info(f"Frame blocked at DMZ outbound by rule {rule}")
return
self.software_manager.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address,
@@ -688,4 +688,9 @@ class Firewall(Router):
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return firewall

View File

@@ -1546,7 +1546,7 @@ class Router(NetworkNode):
print(table)
@classmethod
def from_config(cls, cfg: dict) -> "Router":
def from_config(cls, cfg: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Union
from pydantic import validate_call
from primaite.simulator.network.airspace import AirSpaceFrequency, IPWirelessNetworkInterface
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -121,11 +121,14 @@ class WirelessRouter(Router):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace
def __init__(self, hostname: str, **kwargs):
super().__init__(hostname=hostname, num_ports=0, **kwargs)
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)
self.connect_nic(WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
self.connect_nic(
WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace)
)
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
@@ -215,7 +218,7 @@ class WirelessRouter(Router):
)
@classmethod
def from_config(cls, cfg: Dict) -> "WirelessRouter":
def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
@@ -245,7 +248,7 @@ class WirelessRouter(Router):
operating_state = (
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
)
router = cls(hostname=cfg["hostname"], operating_state=operating_state)
router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"])
if "router_interface" in cfg:
ip_address = cfg["router_interface"]["ip_address"]
subnet_mask = cfg["router_interface"]["subnet_mask"]

View File

@@ -1,5 +1,9 @@
from ipaddress import IPv4Address
import yaml
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -15,6 +19,8 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
def client_server_routed() -> Network:
"""
@@ -279,3 +285,34 @@ def arcd_uc2_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
return network
def _get_example_network(path: str) -> Network:
try:
with open(path, "r") as file:
cfg = yaml.safe_load(file)
except FileNotFoundError:
msg = f"Failed to locate example network config {path}. Run `primaite setup` to load the example config files."
_LOGGER.error(msg)
raise FileNotFoundError(msg)
game = PrimaiteGame.from_config(cfg)
return game.simulation.network
def client_server_p2p_network_example() -> Network:
"""Get the Client-Server P2P example network."""
path = PRIMAITE_PATHS.user_config_path / "example_config" / "client_server_p2p_network_example.yaml"
return _get_example_network(path)
def basic_lan_network_example() -> Network:
"""Get the basic LAN example network."""
path = PRIMAITE_PATHS.user_config_path / "example_config" / "basic_network_network_example.yaml"
return _get_example_network(path)
def multi_lan_internet_network_example() -> Network:
"""Get Multi-LAN with Internet example network."""
path = PRIMAITE_PATHS.user_config_path / "example_config" / "multi_lan_internet_network_example.yaml"
return _get_example_network(path)

View File

View File

@@ -0,0 +1,171 @@
import click
import typer
from rich import print
from rich.table import Table
from typing_extensions import Annotated
from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG
from primaite.simulator import LogLevel
from primaite.utils.cli.primaite_config_utils import is_dev_mode, update_primaite_application_config
dev = typer.Typer()
PRODUCTION_MODE_MESSAGE = (
"\n[green]:rocket::rocket::rocket: "
" PrimAITE is running in Production mode "
" :rocket::rocket::rocket: [/green]\n"
)
DEVELOPER_MODE_MESSAGE = (
"\n[yellow] :construction::construction::construction: "
" PrimAITE is running in Development mode "
" :construction::construction::construction: [/yellow]\n"
)
def dev_mode():
"""
CLI commands relevant to the dev-mode for PrimAITE.
The dev-mode contains tools that help with the ease of developing or debugging PrimAITE.
By default, PrimAITE will be in production mode.
To enable development mode, use `primaite dev-mode enable`
"""
@dev.command()
def show():
"""Show if PrimAITE is in development mode or production mode."""
# print if dev mode is enabled
print(DEVELOPER_MODE_MESSAGE if is_dev_mode() else PRODUCTION_MODE_MESSAGE)
table = Table(title="Current Dev-Mode Settings")
table.add_column("Setting", style="cyan")
table.add_column("Value", style="default")
for setting, value in PRIMAITE_CONFIG["developer_mode"].items():
table.add_row(setting, str(value))
print(table)
print("\nTo see available options, use [cyan]`primaite dev-mode --help`[/cyan]\n")
@dev.command()
def enable():
"""Enable the development mode for PrimAITE."""
# enable dev mode
PRIMAITE_CONFIG["developer_mode"]["enabled"] = True
update_primaite_application_config()
print(DEVELOPER_MODE_MESSAGE)
@dev.command()
def disable():
"""Disable the development mode for PrimAITE."""
# disable dev mode
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
update_primaite_application_config()
print(PRODUCTION_MODE_MESSAGE)
def config_callback(
ctx: typer.Context,
sys_log_level: Annotated[
LogLevel,
typer.Option(
"--sys-log-level",
"-level",
click_type=click.Choice(LogLevel._member_names_, case_sensitive=False),
help="The level of system logs to output.",
show_default=False,
),
] = None,
output_sys_logs: Annotated[
bool,
typer.Option(
"--output-sys-logs/--no-sys-logs", "-sys/-nsys", help="Output system logs to file.", show_default=False
),
] = None,
output_pcap_logs: Annotated[
bool,
typer.Option(
"--output-pcap-logs/--no-pcap-logs",
"-pcap/-npcap",
help="Output network packet capture logs to file.",
show_default=False,
),
] = None,
output_to_terminal: Annotated[
bool,
typer.Option(
"--output-to-terminal/--no-terminal", "-t/-nt", help="Output system logs to terminal.", show_default=False
),
] = None,
):
"""Configure the development tools and environment."""
if ctx.params.get("sys_log_level") is not None:
PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level")
print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}")
if output_sys_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs
print(f"PrimAITE dev-mode config updated {output_sys_logs=}")
if output_pcap_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs
print(f"PrimAITE dev-mode config updated {output_pcap_logs=}")
if output_to_terminal is not None:
PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] = output_to_terminal
print(f"PrimAITE dev-mode config updated {output_to_terminal=}")
# update application config
update_primaite_application_config()
config_typer = typer.Typer(
callback=config_callback,
name="config",
no_args_is_help=True,
invoke_without_command=True,
)
dev.add_typer(config_typer)
@config_typer.command()
def path(
directory: Annotated[
str,
typer.Argument(
help="Directory where the system logs and PCAP logs will be output. By default, this will be where the"
"root of the PrimAITE repository is located.",
show_default=False,
),
] = None,
default: Annotated[
bool,
typer.Option(
"--default",
"-root",
help="Set PrimAITE to output system logs and pcap logs to the PrimAITE repository root.",
),
] = None,
):
"""Set the output directory for the PrimAITE system and PCAP logs."""
if default:
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = None
# update application config
update_primaite_application_config()
print(
f"PrimAITE dev-mode output_dir [cyan]"
f"{str(_PRIMAITE_ROOT.parent.parent / 'simulation_output')}"
f"[/cyan]"
)
return
if directory:
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = directory
# update application config
update_primaite_application_config()
print(f"PrimAITE dev-mode output_dir [cyan]{directory}[/cyan]")

View File

@@ -0,0 +1,22 @@
from typing import Dict, Optional
import yaml
from primaite import PRIMAITE_CONFIG, PRIMAITE_PATHS
def is_dev_mode() -> bool:
"""Returns True if PrimAITE is currently running in developer mode."""
return PRIMAITE_CONFIG.get("developer_mode", {}).get("enabled", False)
def update_primaite_application_config(config: Optional[Dict] = None) -> None:
"""
Update the PrimAITE application config file.
:params: config: Leave empty so that PRIMAITE_CONFIG is used - otherwise provide the Dict
"""
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
if not config:
config = PRIMAITE_CONFIG
yaml.dump(config, file)

View File

@@ -1,11 +0,0 @@
import yaml
from primaite import PRIMAITE_PATHS
def is_dev_mode() -> bool:
"""Returns True if PrimAITE is currently running in developer mode."""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["developer_mode"]