diff --git a/synapse_topology/model/__init__.py b/synapse_topology/model/__init__.py index f8f2353a96..eb7d3456ec 100644 --- a/synapse_topology/model/__init__.py +++ b/synapse_topology/model/__init__.py @@ -1,6 +1,7 @@ -import os.path as path - import yaml +import subprocess + +from os.path import abspath, join from synapse.config.homeserver import HomeServerConfig @@ -12,27 +13,14 @@ from .constants import ( SERVER_NAME, ) from .errors import BasConfigInUseError, BaseConfigNotFoundError, ConfigNotFoundError - -import subprocess +from .config import create_config def set_config_dir(conf_dir): global config_dir global data_dir - config_dir = path.abspath(conf_dir) - data_dir = path.abspath(path.join(config_dir, "./data")) - - -def get_config(sub_config=BASE_CONFIG): - if sub_config: - conf_path = path.join(config_dir, sub_config) - try: - with open(conf_path, "r") as f: - return yaml.safe_load(f) - except FileNotFoundError: - raise BaseConfigNotFoundError() if sub_config == BASE_CONFIG else ConfigNotFoundError( - sub_config - ) + config_dir = abspath(conf_dir) + data_dir = abspath(join(config_dir, "./data")) def get_config_dir(): @@ -43,6 +31,18 @@ def get_data_dir(): return data_dir +def get_config(sub_config=BASE_CONFIG): + if sub_config: + conf_path = join(config_dir, sub_config) + try: + with open(conf_path, "r") as f: + return yaml.safe_load(f) + except FileNotFoundError: + raise BaseConfigNotFoundError() if sub_config == BASE_CONFIG else ConfigNotFoundError( + sub_config + ) + + def set_config(config, sub_config=BASE_CONFIG): if sub_config == BASE_CONFIG and config_in_use(): raise BasConfigInUseError() @@ -70,13 +70,13 @@ def generate_base_config(server_name, report_stats): print(config_dir) conf = HomeServerConfig().generate_config( config_dir, - path.join(config_dir, DATA_SUBDIR), + join(config_dir, DATA_SUBDIR), server_name, generate_secrets=True, report_stats=report_stats, ) - with open(path.join(config_dir, BASE_CONFIG), "w") as f: + with open(join(config_dir, BASE_CONFIG), "w") as f: f.write(conf) f.write(CONFIG_LOCK_DATA) @@ -90,7 +90,7 @@ def get_server_name(): def get_secret_key(): config = get_config() server_name = config.get(SERVER_NAME) - signing_key_path = path.join(config_dir, server_name + ".signing.key") + signing_key_path = join(config_dir, server_name + ".signing.key") subprocess.run(["generate_signing_key.py", "-o", signing_key_path]) with open(signing_key_path, "r") as f: return f.read() @@ -101,10 +101,8 @@ def verify_yaml(): def add_certs(cert, cert_key): - with open( - path.join(config_dir, get_server_name() + ".tls.crt"), "w" - ) as cert_file, open( - path.join(config_dir, get_server_name() + ".tls.key"), "w" + with open(join(config_dir, get_server_name() + ".tls.crt"), "w") as cert_file, open( + join(config_dir, get_server_name() + ".tls.key"), "w" ) as key_file: cert_file.write(cert) key_file.write(cert_key)