diff --git a/examples/server/async_jobs.cpp b/examples/server/async_jobs.cpp index 39c47cfaa..b3f714141 100644 --- a/examples/server/async_jobs.cpp +++ b/examples/server/async_jobs.cpp @@ -155,51 +155,10 @@ bool execute_img_gen_job(ServerRuntime& runtime, std::string& error_message) { sd_img_gen_params_t params = job.img_gen.to_sd_img_gen_params_t(); - SDImageVec results; - int num_results = 0; - - { - std::lock_guard lock(*runtime.sd_ctx_mutex); - sd_image_t* raw_results = generate_image(runtime.sd_ctx, ¶ms); - num_results = params.batch_count; - results.adopt(raw_results, num_results); - } - - if (results.empty() || num_results <= 0) { - error_message = "generate_image returned no results"; + output_images = generate_and_encode(runtime, job.img_gen, error_message); + if (!error_message.empty()) { return false; } - - EncodedImageFormat encoded_format = EncodedImageFormat::PNG; - if (job.img_gen.output_format == "jpeg") { - encoded_format = EncodedImageFormat::JPEG; - } else if (job.img_gen.output_format == "webp") { - encoded_format = EncodedImageFormat::WEBP; - } - - for (int i = 0; i < num_results; ++i) { - if (results[i].data == nullptr) { - continue; - } - - const std::string metadata = job.img_gen.gen_params.embed_image_metadata - ? get_image_params(*runtime.ctx_params, - job.img_gen.gen_params, - job.img_gen.gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(encoded_format, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - metadata, - job.img_gen.output_compression); - if (image_bytes.empty()) { - continue; - } - output_images.push_back(base64_encode(image_bytes)); - } - if (output_images.empty()) { error_message = "generate_image returned empty encoded outputs"; return false; diff --git a/examples/server/frontend b/examples/server/frontend index 740475a7a..1a34176cd 160000 --- a/examples/server/frontend +++ b/examples/server/frontend @@ -1 +1 @@ -Subproject commit 740475a7a6794dc07fb23e8ec5dc56e7e80aa8c1 +Subproject commit 1a34176cd6d39ad3a226b2b69047e71f6797f6bc diff --git a/examples/server/routes_openai.cpp b/examples/server/routes_openai.cpp index af1210459..62df8aaab 100644 --- a/examples/server/routes_openai.cpp +++ b/examples/server/routes_openai.cpp @@ -263,8 +263,8 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); - SDImageVec results; - if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) { + auto strings = generate_and_encode(*runtime, request, error_message); + if (!error_message.empty()) { res.status = 500; res.set_content(json({{"error", error_message}}).dump(), "application/json"); return; @@ -274,35 +274,10 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { out["created"] = static_cast(std::time(nullptr)); out["data"] = json::array(); out["output_format"] = request.output_format; - - for (int i = 0; i < request.gen_params.batch_count; ++i) { - if (results[i].data == nullptr) { - continue; - } - std::string params = request.gen_params.embed_image_metadata - ? get_image_params(*runtime->ctx_params, - request.gen_params, - request.gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" - ? EncodedImageFormat::JPEG - : request.output_format == "webp" - ? EncodedImageFormat::WEBP - : EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params, - request.output_compression); - if (image_bytes.empty()) { - LOG_ERROR("write image to mem failed"); - continue; - } - + for (auto& str: strings) { json item; - item["b64_json"] = base64_encode(image_bytes); - out["data"].push_back(item); + item["b64_json"] = std::move(str); + out["data"].push_back(std::move(item)); } res.set_content(out.dump(), "application/json"); @@ -329,8 +304,8 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); - SDImageVec results; - if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) { + auto strings = generate_and_encode(*runtime, request, error_message); + if (!error_message.empty()) { res.status = 500; res.set_content(json({{"error", error_message}}).dump(), "application/json"); return; @@ -340,26 +315,10 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { out["created"] = static_cast(std::time(nullptr)); out["data"] = json::array(); out["output_format"] = request.output_format; - - for (int i = 0; i < request.gen_params.batch_count; ++i) { - if (results[i].data == nullptr) { - continue; - } - std::string params = request.gen_params.embed_image_metadata - ? get_image_params(*runtime->ctx_params, - request.gen_params, - request.gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" ? EncodedImageFormat::JPEG : EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params, - request.output_compression); + for (auto& str: strings) { json item; - item["b64_json"] = base64_encode(image_bytes); - out["data"].push_back(item); + item["b64_json"] = std::move(str); + out["data"].push_back(std::move(item)); } res.set_content(out.dump(), "application/json"); diff --git a/examples/server/routes_sdapi.cpp b/examples/server/routes_sdapi.cpp index ca6661c0b..b8b2d740d 100644 --- a/examples/server/routes_sdapi.cpp +++ b/examples/server/routes_sdapi.cpp @@ -258,52 +258,16 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); - sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t(); - SDImageVec results; - int num_results = 0; - - { - std::lock_guard lock(*runtime->sd_ctx_mutex); - sd_image_t* raw_results = generate_image(runtime->sd_ctx, &img_gen_params); - num_results = request.gen_params.batch_count; - results.adopt(raw_results, num_results); - } - - if (results.empty()) { - res.status = 500; - res.set_content(R"({"error":"generate_image returned no results"})", "application/json"); - return; - } + std::string error_str; + auto strings = generate_and_encode(*runtime, request, error_str); + if (!error_str.empty()) + throw std::runtime_error(error_str); json out; out["images"] = json::array(); out["parameters"] = j; out["info"] = ""; - - for (int i = 0; i < num_results; ++i) { - if (results[i].data == nullptr) { - continue; - } - - std::string params = request.gen_params.embed_image_metadata - ? get_image_params(*runtime->ctx_params, - request.gen_params, - request.gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params); - - if (image_bytes.empty()) { - LOG_ERROR("write image to mem failed"); - continue; - } - - out["images"].push_back(base64_encode(image_bytes)); - } + out["images"] = std::move(strings); res.set_content(out.dump(), "application/json"); res.status = 200; diff --git a/examples/server/runtime.cpp b/examples/server/runtime.cpp index c29799e3a..d1c4d3e5b 100644 --- a/examples/server/runtime.cpp +++ b/examples/server/runtime.cpp @@ -10,6 +10,7 @@ #include "common/common.h" #include "common/log.h" +#include "common/media_io.h" namespace fs = std::filesystem; @@ -184,3 +185,48 @@ int64_t unix_timestamp_now() { std::chrono::system_clock::now().time_since_epoch()) .count(); } + +std::vector generate_and_encode(ServerRuntime& runtime, ImgGenJobRequest& request, std::string& error_message) { + std::vector strings; + + sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t(); + SDImageVec results; + int num_results = 0; + { + std::lock_guard lock(*runtime.sd_ctx_mutex); + sd_image_t* raw_results = generate_image(runtime.sd_ctx, &img_gen_params); + num_results = request.gen_params.batch_count; + results.adopt(raw_results, num_results); + } + if (results.empty()) { + error_message = "generate_image returned no results"; + } else { + for (int i = 0; i < request.gen_params.batch_count; ++i) { + if (results[i].data == nullptr) { + continue; + } + std::string params = request.gen_params.embed_image_metadata + ? get_image_params(*runtime.ctx_params, + request.gen_params, + request.gen_params.seed + i) + : ""; + auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" + ? EncodedImageFormat::JPEG + : request.output_format == "webp" + ? EncodedImageFormat::WEBP + : EncodedImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + params, + request.output_compression); + if (image_bytes.empty()) { + LOG_ERROR("write image to mem failed"); + continue; + } + strings.push_back(base64_encode(image_bytes)); + } + } + return strings; +} diff --git a/examples/server/runtime.h b/examples/server/runtime.h index 65e932439..da310f7cb 100644 --- a/examples/server/runtime.h +++ b/examples/server/runtime.h @@ -68,3 +68,5 @@ bool assign_output_options(ImgGenJobRequest& request, void refresh_lora_cache(ServerRuntime& rt); std::string get_lora_full_path(ServerRuntime& rt, const std::string& path); int64_t unix_timestamp_now(); + +std::vector generate_and_encode(ServerRuntime& runtime, ImgGenJobRequest& request, std::string& error_message);