@@ -533,7 +533,7 @@ class AutoEncoderKLModel : public GGMLBlock {
533533 const std::string& prefix = " " )
534534 : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
535535 if (sd_version_is_dit (version)) {
536- if (sd_version_is_flux2 (version)) {
536+ if (sd_version_uses_flux2_vae (version)) {
537537 dd_config.z_channels = 32 ;
538538 embed_dim = 32 ;
539539 } else {
@@ -578,7 +578,7 @@ class AutoEncoderKLModel : public GGMLBlock {
578578
579579 ggml_tensor* decode (GGMLRunnerContext* ctx, ggml_tensor* z) {
580580 // z: [N, z_channels, h, w]
581- if (sd_version_is_flux2 (version)) {
581+ if (sd_version_uses_flux2_vae (version)) {
582582 // [N, C*p*p, h, w] -> [N, C, h*p, w*p]
583583 int64_t p = 2 ;
584584
@@ -617,7 +617,7 @@ class AutoEncoderKLModel : public GGMLBlock {
617617 auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" quant_conv" ]);
618618 z = quant_conv->forward (ctx, z); // [N, 2*embed_dim, h/8, w/8]
619619 }
620- if (sd_version_is_flux2 (version)) {
620+ if (sd_version_uses_flux2_vae (version)) {
621621 z = ggml_ext_chunk (ctx->ggml_ctx , z, 2 , 2 )[0 ];
622622
623623 // [N, C, H, W] -> [N, C*p*p, H/p, W/p]
@@ -640,7 +640,7 @@ class AutoEncoderKLModel : public GGMLBlock {
640640
641641 int get_encoder_output_channels () {
642642 int factor = dd_config.double_z ? 2 : 1 ;
643- if (sd_version_is_flux2 (version)) {
643+ if (sd_version_uses_flux2_vae (version)) {
644644 return dd_config.z_channels * 4 ;
645645 }
646646 return dd_config.z_channels * factor;
@@ -673,7 +673,7 @@ struct AutoEncoderKL : public VAE {
673673 } else if (sd_version_is_flux (version) || sd_version_is_z_image (version)) {
674674 scale_factor = 0 .3611f ;
675675 shift_factor = 0 .1159f ;
676- } else if (sd_version_is_flux2 (version)) {
676+ } else if (sd_version_uses_flux2_vae (version)) {
677677 scale_factor = 1 .0f ;
678678 shift_factor = 0 .f ;
679679 }
@@ -747,7 +747,7 @@ struct AutoEncoderKL : public VAE {
747747 }
748748
749749 sd::Tensor<float > vae_output_to_latents (const sd::Tensor<float >& vae_output, std::shared_ptr<RNG> rng) override {
750- if (sd_version_is_flux2 (version)) {
750+ if (sd_version_uses_flux2_vae (version)) {
751751 return vae_output;
752752 } else if (version == VERSION_SD1_PIX2PIX) {
753753 return sd::ops::chunk (vae_output, 2 , 2 )[0 ];
@@ -758,7 +758,7 @@ struct AutoEncoderKL : public VAE {
758758
759759 std::pair<sd::Tensor<float >, sd::Tensor<float >> get_latents_mean_std (const sd::Tensor<float >& latents, int channel_dim) {
760760 GGML_ASSERT (channel_dim >= 0 && static_cast <size_t >(channel_dim) < static_cast <size_t >(latents.dim ()));
761- if (sd_version_is_flux2 (version)) {
761+ if (sd_version_uses_flux2_vae (version)) {
762762 GGML_ASSERT (latents.shape ()[channel_dim] == 128 );
763763 std::vector<int64_t > stats_shape (static_cast <size_t >(latents.dim ()), 1 );
764764 stats_shape[static_cast <size_t >(channel_dim)] = latents.shape ()[channel_dim];
@@ -804,7 +804,7 @@ struct AutoEncoderKL : public VAE {
804804 }
805805
806806 sd::Tensor<float > diffusion_to_vae_latents (const sd::Tensor<float >& latents) override {
807- if (sd_version_is_flux2 (version)) {
807+ if (sd_version_uses_flux2_vae (version)) {
808808 int channel_dim = 2 ;
809809 auto [mean_tensor, std_tensor] = get_latents_mean_std (latents, channel_dim);
810810 return (latents * std_tensor) / scale_factor + mean_tensor;
@@ -813,7 +813,7 @@ struct AutoEncoderKL : public VAE {
813813 }
814814
815815 sd::Tensor<float > vae_to_diffusion_latents (const sd::Tensor<float >& latents) override {
816- if (sd_version_is_flux2 (version)) {
816+ if (sd_version_uses_flux2_vae (version)) {
817817 int channel_dim = 2 ;
818818 auto [mean_tensor, std_tensor] = get_latents_mean_std (latents, channel_dim);
819819 return ((latents - mean_tensor) * scale_factor) / std_tensor;
0 commit comments