Skip to content

Commit 359eb8b

Browse files
authored
refactor: apply RAII ownership to examples (#1392)
1 parent 7397dda commit 359eb8b

6 files changed

Lines changed: 506 additions & 354 deletions

File tree

examples/cli/main.cpp

Lines changed: 54 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "common/common.hpp"
1919
#include "common/media_io.h"
20+
#include "common/resource_owners.hpp"
2021
#include "image_metadata.h"
2122

2223
const char* previews_str[] = {
@@ -275,7 +276,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
275276
}
276277

277278
bool load_images_from_dir(const std::string dir,
278-
std::vector<sd_image_t>& images,
279+
SDImageVec& images,
279280
int expected_width = 0,
280281
int expected_height = 0,
281282
int max_image_num = 0,
@@ -317,7 +318,7 @@ bool load_images_from_dir(const std::string dir,
317318
3,
318319
image_buffer});
319320

320-
if (max_image_num > 0 && images.size() >= max_image_num) {
321+
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
321322
break;
322323
}
323324
}
@@ -554,39 +555,17 @@ int main(int argc, const char* argv[]) {
554555
}
555556
}
556557

557-
bool vae_decode_only = true;
558-
sd_image_t init_image = {0, 0, 3, nullptr};
559-
sd_image_t end_image = {0, 0, 3, nullptr};
560-
sd_image_t control_image = {0, 0, 3, nullptr};
561-
sd_image_t mask_image = {0, 0, 1, nullptr};
562-
std::vector<sd_image_t> ref_images;
563-
std::vector<sd_image_t> pmid_images;
564-
std::vector<sd_image_t> control_frames;
565-
566-
auto release_all_resources = [&]() {
567-
free(init_image.data);
568-
free(end_image.data);
569-
free(control_image.data);
570-
free(mask_image.data);
571-
for (auto image : ref_images) {
572-
free(image.data);
573-
image.data = nullptr;
574-
}
575-
ref_images.clear();
576-
for (auto image : pmid_images) {
577-
free(image.data);
578-
image.data = nullptr;
579-
}
580-
pmid_images.clear();
581-
for (auto image : control_frames) {
582-
free(image.data);
583-
image.data = nullptr;
584-
}
585-
control_frames.clear();
586-
};
558+
bool vae_decode_only = true;
559+
SDImageOwner init_image({0, 0, 3, nullptr});
560+
SDImageOwner end_image({0, 0, 3, nullptr});
561+
SDImageOwner control_image({0, 0, 3, nullptr});
562+
SDImageOwner mask_image({0, 0, 1, nullptr});
563+
SDImageVec ref_images;
564+
SDImageVec pmid_images;
565+
SDImageVec control_frames;
587566

588567
auto load_image_and_update_size = [&](const std::string& path,
589-
sd_image_t& image,
568+
SDImageOwner& image,
590569
bool resize_image = true,
591570
int expected_channel = 3) -> bool {
592571
int expected_width = 0;
@@ -596,13 +575,12 @@ int main(int argc, const char* argv[]) {
596575
expected_height = gen_params.height;
597576
}
598577

599-
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
578+
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) {
600579
LOG_ERROR("load image from '%s' failed", path.c_str());
601-
release_all_resources();
602580
return false;
603581
}
604582

605-
gen_params.set_width_and_height_if_unset(image.width, image.height);
583+
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
606584
return true;
607585
};
608586

@@ -623,47 +601,46 @@ int main(int argc, const char* argv[]) {
623601
if (gen_params.ref_image_paths.size() > 0) {
624602
vae_decode_only = false;
625603
for (auto& path : gen_params.ref_image_paths) {
626-
sd_image_t ref_image = {0, 0, 3, nullptr};
604+
SDImageOwner ref_image({0, 0, 3, nullptr});
627605
if (!load_image_and_update_size(path, ref_image, false)) {
628606
return 1;
629607
}
630-
ref_images.push_back(ref_image);
608+
ref_images.push_back(std::move(ref_image));
631609
}
632610
}
633611

634612
if (gen_params.mask_image_path.size() > 0) {
635-
if (!load_sd_image_from_file(&mask_image,
613+
if (!load_sd_image_from_file(mask_image.put(),
636614
gen_params.mask_image_path.c_str(),
637615
gen_params.get_resolved_width(),
638616
gen_params.get_resolved_height(),
639617
1)) {
640618
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
641-
release_all_resources();
642619
return 1;
643620
}
644621
} else {
645-
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
646-
if (mask_image.data == nullptr) {
622+
sd_image_t generated_mask = {0, 0, 1, nullptr};
623+
generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
624+
if (generated_mask.data == nullptr) {
647625
LOG_ERROR("malloc mask image failed");
648-
release_all_resources();
649626
return 1;
650627
}
651-
mask_image.width = gen_params.get_resolved_width();
652-
mask_image.height = gen_params.get_resolved_height();
653-
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
628+
generated_mask.width = gen_params.get_resolved_width();
629+
generated_mask.height = gen_params.get_resolved_height();
630+
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
631+
mask_image.reset(generated_mask);
654632
}
655633

656634
if (gen_params.control_image_path.size() > 0) {
657-
if (!load_sd_image_from_file(&control_image,
635+
if (!load_sd_image_from_file(control_image.put(),
658636
gen_params.control_image_path.c_str(),
659637
gen_params.get_resolved_width(),
660638
gen_params.get_resolved_height())) {
661639
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
662-
release_all_resources();
663640
return 1;
664641
}
665642
if (cli_params.canny_preprocess) { // apply preprocessor
666-
preprocess_canny(control_image,
643+
preprocess_canny(control_image.get(),
667644
0.08f,
668645
0.08f,
669646
0.8f,
@@ -679,7 +656,6 @@ int main(int argc, const char* argv[]) {
679656
gen_params.get_resolved_height(),
680657
gen_params.video_frames,
681658
cli_params.verbose)) {
682-
release_all_resources();
683659
return 1;
684660
}
685661
}
@@ -691,7 +667,6 @@ int main(int argc, const char* argv[]) {
691667
0,
692668
0,
693669
cli_params.verbose)) {
694-
release_all_resources();
695670
return 1;
696671
}
697672
}
@@ -702,39 +677,30 @@ int main(int argc, const char* argv[]) {
702677

703678
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
704679

705-
sd_image_t* results = nullptr;
706-
int num_results = 0;
680+
SDImageVec results;
681+
int num_results = 0;
707682

708683
if (cli_params.mode == UPSCALE) {
709684
num_results = 1;
710-
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
711-
if (results == nullptr) {
712-
LOG_INFO("failed to allocate results array");
713-
release_all_resources();
714-
return 1;
715-
}
716-
717-
results[0] = init_image;
718-
init_image.data = nullptr;
685+
results.push_back(init_image.release());
719686
} else {
720-
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
687+
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
721688

722689
if (sd_ctx == nullptr) {
723690
LOG_INFO("new_sd_ctx_t failed");
724-
release_all_resources();
725691
return 1;
726692
}
727693

728694
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
729-
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
695+
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
730696
}
731697

732698
if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
733-
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
699+
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
734700
}
735701

736702
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
737-
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
703+
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method);
738704
}
739705

740706
if (cli_params.mode == IMG_GEN) {
@@ -744,19 +710,19 @@ int main(int argc, const char* argv[]) {
744710
gen_params.prompt.c_str(),
745711
gen_params.negative_prompt.c_str(),
746712
gen_params.clip_skip,
747-
init_image,
713+
init_image.get(),
748714
ref_images.data(),
749715
(int)ref_images.size(),
750716
gen_params.auto_resize_ref_image,
751717
gen_params.increase_ref_index,
752-
mask_image,
718+
mask_image.get(),
753719
gen_params.get_resolved_width(),
754720
gen_params.get_resolved_height(),
755721
gen_params.sample_params,
756722
gen_params.strength,
757723
gen_params.seed,
758724
gen_params.batch_count,
759-
control_image,
725+
control_image.get(),
760726
gen_params.control_strength,
761727
{
762728
pmid_images.data(),
@@ -768,17 +734,17 @@ int main(int argc, const char* argv[]) {
768734
gen_params.cache_params,
769735
};
770736

771-
results = generate_image(sd_ctx, &img_gen_params);
772737
num_results = gen_params.batch_count;
738+
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
773739
} else if (cli_params.mode == VID_GEN) {
774740
sd_vid_gen_params_t vid_gen_params = {
775741
gen_params.lora_vec.data(),
776742
static_cast<uint32_t>(gen_params.lora_vec.size()),
777743
gen_params.prompt.c_str(),
778744
gen_params.negative_prompt.c_str(),
779745
gen_params.clip_skip,
780-
init_image,
781-
end_image,
746+
init_image.get(),
747+
end_image.get(),
782748
control_frames.data(),
783749
(int)control_frames.size(),
784750
gen_params.get_resolved_width(),
@@ -794,25 +760,23 @@ int main(int argc, const char* argv[]) {
794760
gen_params.cache_params,
795761
};
796762

797-
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
763+
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
764+
results.adopt(generated_video, num_results);
798765
}
799766

800-
if (results == nullptr) {
767+
if (!results) {
801768
LOG_ERROR("generate failed");
802-
free_sd_ctx(sd_ctx);
803769
return 1;
804770
}
805-
806-
free_sd_ctx(sd_ctx);
807771
}
808772

809773
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
810774
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
811-
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
812-
ctx_params.offload_params_to_cpu,
813-
ctx_params.diffusion_conv_direct,
814-
ctx_params.n_threads,
815-
gen_params.upscale_tile_size);
775+
UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
776+
ctx_params.offload_params_to_cpu,
777+
ctx_params.diffusion_conv_direct,
778+
ctx_params.n_threads,
779+
gen_params.upscale_tile_size));
816780

817781
if (upscaler_ctx == nullptr) {
818782
LOG_ERROR("new_upscaler_ctx failed");
@@ -821,32 +785,24 @@ int main(int argc, const char* argv[]) {
821785
if (results[i].data == nullptr) {
822786
continue;
823787
}
824-
sd_image_t current_image = results[i];
788+
SDImageOwner current_image(results[i]);
789+
results[i] = {0, 0, 0, nullptr};
825790
for (int u = 0; u < gen_params.upscale_repeats; ++u) {
826-
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
827-
if (upscaled_image.data == nullptr) {
791+
SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor));
792+
if (upscaled_image.get().data == nullptr) {
828793
LOG_ERROR("upscale failed");
829794
break;
830795
}
831-
free(current_image.data);
832-
current_image = upscaled_image;
796+
current_image = std::move(upscaled_image);
833797
}
834-
results[i] = current_image; // Set the final upscaled image as the result
798+
results[i] = current_image.release(); // Set the final upscaled image as the result
835799
}
836800
}
837801
}
838802

839-
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
803+
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
840804
return 1;
841805
}
842806

843-
for (int i = 0; i < num_results; i++) {
844-
free(results[i].data);
845-
results[i].data = nullptr;
846-
}
847-
free(results);
848-
849-
release_all_resources();
850-
851807
return 0;
852808
}

examples/common/common.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace fs = std::filesystem;
2020
#endif // _WIN32
2121

2222
#include "log.h"
23+
#include "resource_owners.hpp"
2324
#include "stable-diffusion.h"
2425

2526
#define SAFE_STR(s) ((s) ? (s) : "")
@@ -1751,8 +1752,8 @@ struct SDGenerationParams {
17511752
}
17521753

17531754
std::string to_string() const {
1754-
char* sample_params_str = sd_sample_params_to_str(&sample_params);
1755-
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);
1755+
FreeUniquePtr<char> sample_params_str(sd_sample_params_to_str(&sample_params));
1756+
FreeUniquePtr<char> high_noise_sample_params_str(sd_sample_params_to_str(&high_noise_sample_params));
17561757

17571758
std::ostringstream lora_ss;
17581759
lora_ss << "{\n";
@@ -1801,9 +1802,9 @@ struct SDGenerationParams {
18011802
<< " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n"
18021803
<< " pm_style_strength: " << pm_style_strength << ",\n"
18031804
<< " skip_layers: " << vec_to_string(skip_layers) << ",\n"
1804-
<< " sample_params: " << sample_params_str << ",\n"
1805+
<< " sample_params: " << SAFE_STR(sample_params_str.get()) << ",\n"
18051806
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
1806-
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
1807+
<< " high_noise_sample_params: " << SAFE_STR(high_noise_sample_params_str.get()) << ",\n"
18071808
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
18081809
<< " cache_mode: \"" << cache_mode << "\",\n"
18091810
<< " cache_option: \"" << cache_option << "\",\n"
@@ -1829,8 +1830,6 @@ struct SDGenerationParams {
18291830
<< vae_tiling_params.rel_size_x << ", "
18301831
<< vae_tiling_params.rel_size_y << " },\n"
18311832
<< "}";
1832-
free(sample_params_str);
1833-
free(high_noise_sample_params_str);
18341833
return oss.str();
18351834
}
18361835
};

0 commit comments

Comments
 (0)