diff --git a/main.py b/main.py index 9f62fdf410..907d922188 100644 --- a/main.py +++ b/main.py @@ -79,25 +79,38 @@ async def check_dashboard_files(webui_dir: str | None = None): data_dist_path = os.path.join(get_astrbot_data_path(), "dist") if os.path.exists(data_dist_path): - v = await get_dashboard_version() - if should_use_bundled_dashboard_dist(data_dist_path, VERSION): - bundled_dist = get_bundled_dashboard_dist_path() - logger.info( - "Using bundled WebUI because data/dist is older than core version v%s.", - VERSION, + index_html = os.path.join(data_dist_path, "index.html") + assets_dir = os.path.join(data_dist_path, "assets") + version_file = os.path.join(assets_dir, "version") + if not ( + os.path.isfile(index_html) + and os.path.isdir(assets_dir) + and os.path.isfile(version_file) + ): + logger.warning( + "WebUI directory is incomplete: %s. Downloading a fresh copy.", + data_dist_path, ) - return str(bundled_dist) - if v is not None: - # 存在文件 - if v == f"v{VERSION}": - logger.info("WebUI is up to date.") - else: - logger.warning( - "WebUI version mismatch: %s, expected v%s.", - v, + else: + v = await get_dashboard_version() + if should_use_bundled_dashboard_dist(data_dist_path, VERSION): + bundled_dist = get_bundled_dashboard_dist_path() + logger.info( + "Using bundled WebUI because data/dist is older than core version v%s.", VERSION, ) - return data_dist_path + return str(bundled_dist) + if v is not None: + # 存在文件 + if v == f"v{VERSION}": + logger.info("WebUI is up to date.") + else: + logger.warning( + "WebUI version mismatch: %s, expected v%s.", + v, + VERSION, + ) + return data_dist_path logger.info( "Downloading WebUI. If it fails, download dist.zip from https://github.com/AstrBotDevs/AstrBot/releases/latest and extract dist to data/.", diff --git a/tests/test_main.py b/tests/test_main.py index ae60366315..1b681c27b6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -10,7 +10,7 @@ import pytest from astrbot.core.utils.io import should_use_bundled_dashboard_dist -from main import check_dashboard_files, check_env +from main import VERSION, check_dashboard_files, check_env class _version_info: @@ -158,6 +158,8 @@ async def test_check_dashboard_files_exists_and_version_match(monkeypatch): """Tests that dashboard is not downloaded when it exists and version matches.""" # Mock os.path.exists to return True monkeypatch.setattr(os.path, "exists", lambda x: True) + monkeypatch.setattr(os.path, "isfile", lambda x: True) + monkeypatch.setattr(os.path, "isdir", lambda x: True) # Mock get_dashboard_version to return the current version with mock.patch("main.get_dashboard_version") as mock_get_version: @@ -176,6 +178,8 @@ async def test_check_dashboard_files_exists_and_version_match(monkeypatch): async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): """Tests that a warning is logged when dashboard version mismatches.""" monkeypatch.setattr(os.path, "exists", lambda x: True) + monkeypatch.setattr(os.path, "isfile", lambda x: True) + monkeypatch.setattr(os.path, "isdir", lambda x: True) with mock.patch( "main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1") @@ -225,6 +229,9 @@ async def test_check_dashboard_files_uses_bundled_dist_when_data_dist_is_stale( data_dist = data_dir / "dist" bundled_dist = tmp_path / "bundled-dist" data_dist.mkdir(parents=True) + (data_dist / "index.html").write_text("", encoding="utf-8") + (data_dist / "assets").mkdir() + (data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8") bundled_dist.mkdir() with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)): @@ -245,6 +252,25 @@ async def test_check_dashboard_files_uses_bundled_dist_when_data_dist_is_stale( mock_download.assert_not_called() +@pytest.mark.asyncio +async def test_check_dashboard_files_redownloads_incomplete_data_dist(tmp_path): + """Tests that a partial data/dist does not skip dashboard download.""" + data_dir = tmp_path / "data" + data_dist = data_dir / "dist" + data_dist.mkdir(parents=True) + + with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)): + with mock.patch("main.get_dashboard_version") as mock_get_version: + with mock.patch( + "main.download_dashboard", mock.AsyncMock() + ) as mock_download: + result = await check_dashboard_files() + + assert result == str(data_dist) + mock_get_version.assert_not_called() + mock_download.assert_awaited_once_with(version=f"v{VERSION}", latest=False) + + @pytest.mark.asyncio async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch): """Tests that providing a valid webui_dir skips all checks."""