diff --git a/src/xc_integrator/integrator_util/onedft_util.cxx b/src/xc_integrator/integrator_util/onedft_util.cxx index 13dff118..20593ab4 100644 --- a/src/xc_integrator/integrator_util/onedft_util.cxx +++ b/src/xc_integrator/integrator_util/onedft_util.cxx @@ -52,8 +52,26 @@ std::string map_model(const std::string& model, torch::DeviceType device) { } else if (model == "LDA") { return model_path + "/lda.fun"; } else if (model == "SKALA") { - GAUXC_GENERIC_EXCEPTION("To use the Skala functional, specify a local checkpoint path."); + const std::string install_path = std::string(GAUXC_ONEDFT_MODEL_PATH_INSTALL); + std::vector search_paths; + search_paths.push_back(model_path); + if (std::filesystem::exists(install_path) && install_path != model_path) { + search_paths.push_back(install_path); + } + for (const auto& path : search_paths) { + if (device == torch::kCUDA && std::filesystem::exists(path + "/skala-1.1-cuda.fun")) { + return path + "/skala-1.1-cuda.fun"; + } + if (std::filesystem::exists(path + "/skala-1.1.fun")) { + return path + "/skala-1.1.fun"; + } + } + GAUXC_GENERIC_EXCEPTION("To use the Skala functional, install skala-1.1.fun or specify a local checkpoint path."); } else { + const std::string install_path = std::string(GAUXC_ONEDFT_MODEL_PATH_INSTALL); + if (std::filesystem::exists(install_path + "/" + model)) { + return install_path + "/" + model; + } GAUXC_GENERIC_EXCEPTION("Model " + model + " not found in " + model_path); } }