diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index c42088ba9e..b7da0ff431 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -2,6 +2,7 @@ import json import logging import os +import tempfile from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.auth_password import ( @@ -53,9 +54,9 @@ def __init__( if not self.check_exist(): """不存在时载入默认配置""" - with open(config_path, "w", encoding="utf-8-sig") as f: - json.dump(default_config, f, indent=4, ensure_ascii=False) - object.__setattr__(self, "first_deploy", True) # 标记第一次部署 + self.update(default_config) + self.save_config(indent=4) + object.__setattr__(self, "first_deploy", True) # 标记第一次部署 with open(config_path, encoding="utf-8-sig") as f: conf_str = f.read() @@ -211,15 +212,33 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict | None = None) -> None: + def save_config( + self, replace_config: dict | None = None, *, indent: int = 2 + ) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config """ if replace_config: self.update(replace_config) - with open(self.config_path, "w", encoding="utf-8-sig") as f: - json.dump(self, f, indent=2, ensure_ascii=False) + directory = os.path.dirname(os.path.abspath(self.config_path)) or "." + fd, temp_path = tempfile.mkstemp( + dir=directory, + prefix=f".{os.path.basename(self.config_path)}.", + suffix=".tmp", + ) + try: + with os.fdopen(fd, "w", encoding="utf-8-sig") as f: + json.dump(self, f, indent=indent, ensure_ascii=False) + f.flush() + os.fsync(f.fileno()) + os.replace(temp_path, self.config_path) + except Exception: + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + raise def __getattr__(self, item): try: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 7afe82ebed..d6747930d2 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -529,6 +529,38 @@ def test_save_config_with_replace(self, temp_config_path, minimal_default_config # Original fields are preserved because update merges assert "platform_settings" in loaded_config + def test_save_config_preserves_existing_file_when_write_fails( + self, temp_config_path, minimal_default_config, monkeypatch + ): + """Config saves should not corrupt the existing file on write failure.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + with open(temp_config_path, encoding="utf-8-sig") as f: + original_content = f.read() + + def failing_dump(*args, **kwargs): + file_obj = args[1] + file_obj.write("{") + raise RuntimeError("simulated interrupted write") + + config.new_field = "new_value" + monkeypatch.setattr( + "astrbot.core.config.astrbot_config.json.dump", + failing_dump, + ) + + with pytest.raises(RuntimeError, match="simulated interrupted write"): + config.save_config() + + with open(temp_config_path, encoding="utf-8-sig") as f: + assert f.read() == original_content + assert [ + entry.name + for entry in os.scandir(os.path.dirname(temp_config_path)) + if entry.name != os.path.basename(temp_config_path) + ] == [] + def test_modification_persists_after_reload( self, temp_config_path, minimal_default_config ):