1717
1818#include " common/common.hpp"
1919#include " common/media_io.h"
20+ #include " common/resource_owners.hpp"
2021#include " image_metadata.h"
2122
2223const char * previews_str[] = {
@@ -275,7 +276,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
275276}
276277
277278bool load_images_from_dir (const std::string dir,
278- std::vector< sd_image_t > & images,
279+ SDImageVec & images,
279280 int expected_width = 0 ,
280281 int expected_height = 0 ,
281282 int max_image_num = 0 ,
@@ -317,7 +318,7 @@ bool load_images_from_dir(const std::string dir,
317318 3 ,
318319 image_buffer});
319320
320- if (max_image_num > 0 && images.size () >= max_image_num) {
321+ if (max_image_num > 0 && static_cast < int >( images.size () ) >= max_image_num) {
321322 break ;
322323 }
323324 }
@@ -554,39 +555,17 @@ int main(int argc, const char* argv[]) {
554555 }
555556 }
556557
557- bool vae_decode_only = true ;
558- sd_image_t init_image = {0 , 0 , 3 , nullptr };
559- sd_image_t end_image = {0 , 0 , 3 , nullptr };
560- sd_image_t control_image = {0 , 0 , 3 , nullptr };
561- sd_image_t mask_image = {0 , 0 , 1 , nullptr };
562- std::vector<sd_image_t > ref_images;
563- std::vector<sd_image_t > pmid_images;
564- std::vector<sd_image_t > control_frames;
565-
566- auto release_all_resources = [&]() {
567- free (init_image.data );
568- free (end_image.data );
569- free (control_image.data );
570- free (mask_image.data );
571- for (auto image : ref_images) {
572- free (image.data );
573- image.data = nullptr ;
574- }
575- ref_images.clear ();
576- for (auto image : pmid_images) {
577- free (image.data );
578- image.data = nullptr ;
579- }
580- pmid_images.clear ();
581- for (auto image : control_frames) {
582- free (image.data );
583- image.data = nullptr ;
584- }
585- control_frames.clear ();
586- };
558+ bool vae_decode_only = true ;
559+ SDImageOwner init_image ({0 , 0 , 3 , nullptr });
560+ SDImageOwner end_image ({0 , 0 , 3 , nullptr });
561+ SDImageOwner control_image ({0 , 0 , 3 , nullptr });
562+ SDImageOwner mask_image ({0 , 0 , 1 , nullptr });
563+ SDImageVec ref_images;
564+ SDImageVec pmid_images;
565+ SDImageVec control_frames;
587566
588567 auto load_image_and_update_size = [&](const std::string& path,
589- sd_image_t & image,
568+ SDImageOwner & image,
590569 bool resize_image = true ,
591570 int expected_channel = 3 ) -> bool {
592571 int expected_width = 0 ;
@@ -596,13 +575,12 @@ int main(int argc, const char* argv[]) {
596575 expected_height = gen_params.height ;
597576 }
598577
599- if (!load_sd_image_from_file (& image, path.c_str (), expected_width, expected_height, expected_channel)) {
578+ if (!load_sd_image_from_file (image. put () , path.c_str (), expected_width, expected_height, expected_channel)) {
600579 LOG_ERROR (" load image from '%s' failed" , path.c_str ());
601- release_all_resources ();
602580 return false ;
603581 }
604582
605- gen_params.set_width_and_height_if_unset (image.width , image.height );
583+ gen_params.set_width_and_height_if_unset (image.get (). width , image. get () .height );
606584 return true ;
607585 };
608586
@@ -623,47 +601,46 @@ int main(int argc, const char* argv[]) {
623601 if (gen_params.ref_image_paths .size () > 0 ) {
624602 vae_decode_only = false ;
625603 for (auto & path : gen_params.ref_image_paths ) {
626- sd_image_t ref_image = {0 , 0 , 3 , nullptr };
604+ SDImageOwner ref_image ( {0 , 0 , 3 , nullptr }) ;
627605 if (!load_image_and_update_size (path, ref_image, false )) {
628606 return 1 ;
629607 }
630- ref_images.push_back (ref_image);
608+ ref_images.push_back (std::move ( ref_image) );
631609 }
632610 }
633611
634612 if (gen_params.mask_image_path .size () > 0 ) {
635- if (!load_sd_image_from_file (& mask_image,
613+ if (!load_sd_image_from_file (mask_image. put () ,
636614 gen_params.mask_image_path .c_str (),
637615 gen_params.get_resolved_width (),
638616 gen_params.get_resolved_height (),
639617 1 )) {
640618 LOG_ERROR (" load image from '%s' failed" , gen_params.mask_image_path .c_str ());
641- release_all_resources ();
642619 return 1 ;
643620 }
644621 } else {
645- mask_image.data = (uint8_t *)malloc (gen_params.get_resolved_width () * gen_params.get_resolved_height ());
646- if (mask_image.data == nullptr ) {
622+ sd_image_t generated_mask = {0 , 0 , 1 , nullptr };
623+ generated_mask.data = (uint8_t *)malloc (gen_params.get_resolved_width () * gen_params.get_resolved_height ());
624+ if (generated_mask.data == nullptr ) {
647625 LOG_ERROR (" malloc mask image failed" );
648- release_all_resources ();
649626 return 1 ;
650627 }
651- mask_image.width = gen_params.get_resolved_width ();
652- mask_image.height = gen_params.get_resolved_height ();
653- memset (mask_image.data , 255 , gen_params.get_resolved_width () * gen_params.get_resolved_height ());
628+ generated_mask.width = gen_params.get_resolved_width ();
629+ generated_mask.height = gen_params.get_resolved_height ();
630+ memset (generated_mask.data , 255 , gen_params.get_resolved_width () * gen_params.get_resolved_height ());
631+ mask_image.reset (generated_mask);
654632 }
655633
656634 if (gen_params.control_image_path .size () > 0 ) {
657- if (!load_sd_image_from_file (& control_image,
635+ if (!load_sd_image_from_file (control_image. put () ,
658636 gen_params.control_image_path .c_str (),
659637 gen_params.get_resolved_width (),
660638 gen_params.get_resolved_height ())) {
661639 LOG_ERROR (" load image from '%s' failed" , gen_params.control_image_path .c_str ());
662- release_all_resources ();
663640 return 1 ;
664641 }
665642 if (cli_params.canny_preprocess ) { // apply preprocessor
666- preprocess_canny (control_image,
643+ preprocess_canny (control_image. get () ,
667644 0 .08f ,
668645 0 .08f ,
669646 0 .8f ,
@@ -679,7 +656,6 @@ int main(int argc, const char* argv[]) {
679656 gen_params.get_resolved_height (),
680657 gen_params.video_frames ,
681658 cli_params.verbose )) {
682- release_all_resources ();
683659 return 1 ;
684660 }
685661 }
@@ -691,7 +667,6 @@ int main(int argc, const char* argv[]) {
691667 0 ,
692668 0 ,
693669 cli_params.verbose )) {
694- release_all_resources ();
695670 return 1 ;
696671 }
697672 }
@@ -702,39 +677,30 @@ int main(int argc, const char* argv[]) {
702677
703678 sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t (vae_decode_only, true , cli_params.taesd_preview );
704679
705- sd_image_t * results = nullptr ;
706- int num_results = 0 ;
680+ SDImageVec results;
681+ int num_results = 0 ;
707682
708683 if (cli_params.mode == UPSCALE) {
709684 num_results = 1 ;
710- results = (sd_image_t *)calloc (num_results, sizeof (sd_image_t ));
711- if (results == nullptr ) {
712- LOG_INFO (" failed to allocate results array" );
713- release_all_resources ();
714- return 1 ;
715- }
716-
717- results[0 ] = init_image;
718- init_image.data = nullptr ;
685+ results.push_back (init_image.release ());
719686 } else {
720- sd_ctx_t * sd_ctx = new_sd_ctx (&sd_ctx_params);
687+ SDCtxPtr sd_ctx ( new_sd_ctx (&sd_ctx_params) );
721688
722689 if (sd_ctx == nullptr ) {
723690 LOG_INFO (" new_sd_ctx_t failed" );
724- release_all_resources ();
725691 return 1 ;
726692 }
727693
728694 if (gen_params.sample_params .sample_method == SAMPLE_METHOD_COUNT) {
729- gen_params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
695+ gen_params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx. get () );
730696 }
731697
732698 if (gen_params.high_noise_sample_params .sample_method == SAMPLE_METHOD_COUNT) {
733- gen_params.high_noise_sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
699+ gen_params.high_noise_sample_params .sample_method = sd_get_default_sample_method (sd_ctx. get () );
734700 }
735701
736702 if (gen_params.sample_params .scheduler == SCHEDULER_COUNT) {
737- gen_params.sample_params .scheduler = sd_get_default_scheduler (sd_ctx, gen_params.sample_params .sample_method );
703+ gen_params.sample_params .scheduler = sd_get_default_scheduler (sd_ctx. get () , gen_params.sample_params .sample_method );
738704 }
739705
740706 if (cli_params.mode == IMG_GEN) {
@@ -744,19 +710,19 @@ int main(int argc, const char* argv[]) {
744710 gen_params.prompt .c_str (),
745711 gen_params.negative_prompt .c_str (),
746712 gen_params.clip_skip ,
747- init_image,
713+ init_image. get () ,
748714 ref_images.data (),
749715 (int )ref_images.size (),
750716 gen_params.auto_resize_ref_image ,
751717 gen_params.increase_ref_index ,
752- mask_image,
718+ mask_image. get () ,
753719 gen_params.get_resolved_width (),
754720 gen_params.get_resolved_height (),
755721 gen_params.sample_params ,
756722 gen_params.strength ,
757723 gen_params.seed ,
758724 gen_params.batch_count ,
759- control_image,
725+ control_image. get () ,
760726 gen_params.control_strength ,
761727 {
762728 pmid_images.data (),
@@ -768,17 +734,17 @@ int main(int argc, const char* argv[]) {
768734 gen_params.cache_params ,
769735 };
770736
771- results = generate_image (sd_ctx, &img_gen_params);
772737 num_results = gen_params.batch_count ;
738+ results.adopt (generate_image (sd_ctx.get (), &img_gen_params), num_results);
773739 } else if (cli_params.mode == VID_GEN) {
774740 sd_vid_gen_params_t vid_gen_params = {
775741 gen_params.lora_vec .data (),
776742 static_cast <uint32_t >(gen_params.lora_vec .size ()),
777743 gen_params.prompt .c_str (),
778744 gen_params.negative_prompt .c_str (),
779745 gen_params.clip_skip ,
780- init_image,
781- end_image,
746+ init_image. get () ,
747+ end_image. get () ,
782748 control_frames.data (),
783749 (int )control_frames.size (),
784750 gen_params.get_resolved_width (),
@@ -794,25 +760,23 @@ int main(int argc, const char* argv[]) {
794760 gen_params.cache_params ,
795761 };
796762
797- results = generate_video (sd_ctx, &vid_gen_params, &num_results);
763+ sd_image_t * generated_video = generate_video (sd_ctx.get (), &vid_gen_params, &num_results);
764+ results.adopt (generated_video, num_results);
798765 }
799766
800- if (results == nullptr ) {
767+ if (! results) {
801768 LOG_ERROR (" generate failed" );
802- free_sd_ctx (sd_ctx);
803769 return 1 ;
804770 }
805-
806- free_sd_ctx (sd_ctx);
807771 }
808772
809773 int upscale_factor = 4 ; // unused for RealESRGAN_x4plus_anime_6B.pth
810774 if (ctx_params.esrgan_path .size () > 0 && gen_params.upscale_repeats > 0 ) {
811- upscaler_ctx_t * upscaler_ctx = new_upscaler_ctx (ctx_params.esrgan_path .c_str (),
812- ctx_params.offload_params_to_cpu ,
813- ctx_params.diffusion_conv_direct ,
814- ctx_params.n_threads ,
815- gen_params.upscale_tile_size );
775+ UpscalerCtxPtr upscaler_ctx ( new_upscaler_ctx (ctx_params.esrgan_path .c_str (),
776+ ctx_params.offload_params_to_cpu ,
777+ ctx_params.diffusion_conv_direct ,
778+ ctx_params.n_threads ,
779+ gen_params.upscale_tile_size ) );
816780
817781 if (upscaler_ctx == nullptr ) {
818782 LOG_ERROR (" new_upscaler_ctx failed" );
@@ -821,32 +785,24 @@ int main(int argc, const char* argv[]) {
821785 if (results[i].data == nullptr ) {
822786 continue ;
823787 }
824- sd_image_t current_image = results[i];
788+ SDImageOwner current_image (results[i]);
789+ results[i] = {0 , 0 , 0 , nullptr };
825790 for (int u = 0 ; u < gen_params.upscale_repeats ; ++u) {
826- sd_image_t upscaled_image = upscale (upscaler_ctx, current_image, upscale_factor);
827- if (upscaled_image.data == nullptr ) {
791+ SDImageOwner upscaled_image ( upscale (upscaler_ctx. get () , current_image. get () , upscale_factor) );
792+ if (upscaled_image.get (). data == nullptr ) {
828793 LOG_ERROR (" upscale failed" );
829794 break ;
830795 }
831- free (current_image.data );
832- current_image = upscaled_image;
796+ current_image = std::move (upscaled_image);
833797 }
834- results[i] = current_image; // Set the final upscaled image as the result
798+ results[i] = current_image. release () ; // Set the final upscaled image as the result
835799 }
836800 }
837801 }
838802
839- if (!save_results (cli_params, ctx_params, gen_params, results, num_results)) {
803+ if (!save_results (cli_params, ctx_params, gen_params, results. data () , num_results)) {
840804 return 1 ;
841805 }
842806
843- for (int i = 0 ; i < num_results; i++) {
844- free (results[i].data );
845- results[i].data = nullptr ;
846- }
847- free (results);
848-
849- release_all_resources ();
850-
851807 return 0 ;
852808}
0 commit comments