Skip to content

Commit 5c243db

Browse files
authored
feat: add ernie image support (#1427)
1 parent c41c5de commit 5c243db

14 files changed

Lines changed: 699 additions & 20 deletions

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ API and command-line option may change frequently.***
5757
- [Z-Image](./docs/z_image.md)
5858
- [Ovis-Image](./docs/ovis_image.md)
5959
- [Anima](./docs/anima.md)
60+
- [ERNIE-Image](./docs/ernie_image.md)
6061
- Image Edit Models
6162
- [FLUX.1-Kontext-dev](./docs/kontext.md)
6263
- [Qwen Image Edit series](./docs/qwen_image_edit.md)
@@ -144,6 +145,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
144145
- [🔥Z-Image](./docs/z_image.md)
145146
- [Ovis-Image](./docs/ovis_image.md)
146147
- [Anima](./docs/anima.md)
148+
- [ERNIE-Image](./docs/ernie_image.md)
147149
- [LoRA](./docs/lora.md)
148150
- [LCM/LCM-LoRA](./docs/lcm.md)
149151
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

assets/ernie_image/example.png

595 KB
Loading
562 KB
Loading

docs/ernie_image.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# How to Use
2+
3+
You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less.
4+
5+
## Download weights
6+
7+
- Download ERNIE-Image-Turbo
8+
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
9+
- gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main
10+
- Download ERNIE-Image
11+
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
12+
- gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main
13+
- Download vae
14+
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae
15+
- Download the Ministral 3B text encoder
16+
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders
17+
- gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main
18+
19+
## Examples
20+
21+
### ERNIE-Image-Turbo
22+
23+
```
24+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa
25+
```
26+
27+
<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />
28+
29+
### ERNIE-Image
30+
31+
```
32+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
33+
```
34+
35+
<img width="256" alt="ERNIE-Image example" src="../assets/ernie_image/example.png" />

src/auto_encoder_kl.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ class AutoEncoderKLModel : public GGMLBlock {
533533
const std::string& prefix = "")
534534
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
535535
if (sd_version_is_dit(version)) {
536-
if (sd_version_is_flux2(version)) {
536+
if (sd_version_uses_flux2_vae(version)) {
537537
dd_config.z_channels = 32;
538538
embed_dim = 32;
539539
} else {
@@ -578,7 +578,7 @@ class AutoEncoderKLModel : public GGMLBlock {
578578

579579
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
580580
// z: [N, z_channels, h, w]
581-
if (sd_version_is_flux2(version)) {
581+
if (sd_version_uses_flux2_vae(version)) {
582582
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
583583
int64_t p = 2;
584584

@@ -617,7 +617,7 @@ class AutoEncoderKLModel : public GGMLBlock {
617617
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
618618
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
619619
}
620-
if (sd_version_is_flux2(version)) {
620+
if (sd_version_uses_flux2_vae(version)) {
621621
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
622622

623623
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
@@ -640,7 +640,7 @@ class AutoEncoderKLModel : public GGMLBlock {
640640

641641
int get_encoder_output_channels() {
642642
int factor = dd_config.double_z ? 2 : 1;
643-
if (sd_version_is_flux2(version)) {
643+
if (sd_version_uses_flux2_vae(version)) {
644644
return dd_config.z_channels * 4;
645645
}
646646
return dd_config.z_channels * factor;
@@ -673,7 +673,7 @@ struct AutoEncoderKL : public VAE {
673673
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
674674
scale_factor = 0.3611f;
675675
shift_factor = 0.1159f;
676-
} else if (sd_version_is_flux2(version)) {
676+
} else if (sd_version_uses_flux2_vae(version)) {
677677
scale_factor = 1.0f;
678678
shift_factor = 0.f;
679679
}
@@ -747,7 +747,7 @@ struct AutoEncoderKL : public VAE {
747747
}
748748

749749
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
750-
if (sd_version_is_flux2(version)) {
750+
if (sd_version_uses_flux2_vae(version)) {
751751
return vae_output;
752752
} else if (version == VERSION_SD1_PIX2PIX) {
753753
return sd::ops::chunk(vae_output, 2, 2)[0];
@@ -758,7 +758,7 @@ struct AutoEncoderKL : public VAE {
758758

759759
std::pair<sd::Tensor<float>, sd::Tensor<float>> get_latents_mean_std(const sd::Tensor<float>& latents, int channel_dim) {
760760
GGML_ASSERT(channel_dim >= 0 && static_cast<size_t>(channel_dim) < static_cast<size_t>(latents.dim()));
761-
if (sd_version_is_flux2(version)) {
761+
if (sd_version_uses_flux2_vae(version)) {
762762
GGML_ASSERT(latents.shape()[channel_dim] == 128);
763763
std::vector<int64_t> stats_shape(static_cast<size_t>(latents.dim()), 1);
764764
stats_shape[static_cast<size_t>(channel_dim)] = latents.shape()[channel_dim];
@@ -804,7 +804,7 @@ struct AutoEncoderKL : public VAE {
804804
}
805805

806806
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
807-
if (sd_version_is_flux2(version)) {
807+
if (sd_version_uses_flux2_vae(version)) {
808808
int channel_dim = 2;
809809
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
810810
return (latents * std_tensor) / scale_factor + mean_tensor;
@@ -813,7 +813,7 @@ struct AutoEncoderKL : public VAE {
813813
}
814814

815815
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
816-
if (sd_version_is_flux2(version)) {
816+
if (sd_version_uses_flux2_vae(version)) {
817817
int channel_dim = 2;
818818
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
819819
return ((latents - mean_tensor) * scale_factor) / std_tensor;

src/conditioner.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1621,10 +1621,12 @@ struct LLMEmbedder : public Conditioner {
16211621
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
16221622
if (version == VERSION_FLUX2) {
16231623
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
1624+
} else if (sd_version_is_ernie_image(version)) {
1625+
arch = LLM::LLMArch::MINISTRAL_3_3B;
16241626
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
16251627
arch = LLM::LLMArch::QWEN3;
16261628
}
1627-
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
1629+
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) {
16281630
tokenizer = std::make_shared<MistralTokenizer>();
16291631
} else {
16301632
tokenizer = std::make_shared<Qwen2Tokenizer>();
@@ -1867,6 +1869,13 @@ struct LLMEmbedder : public Conditioner {
18671869
prompt_attn_range.second = static_cast<int>(prompt.size());
18681870

18691871
prompt += "[/INST]";
1872+
} else if (sd_version_is_ernie_image(version)) {
1873+
prompt_template_encode_start_idx = 0;
1874+
out_layers = {25}; // -2
1875+
1876+
prompt_attn_range.first = 0;
1877+
prompt += conditioner_params.text;
1878+
prompt_attn_range.second = static_cast<int>(prompt.size());
18701879
} else if (sd_version_is_z_image(version)) {
18711880
prompt_template_encode_start_idx = 0;
18721881
out_layers = {35}; // -2

src/diffusion_model.hpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <optional>
55
#include "anima.hpp"
6+
#include "ernie_image.hpp"
67
#include "flux.hpp"
78
#include "mmdit.hpp"
89
#include "qwen_image.hpp"
@@ -516,4 +517,66 @@ struct ZImageModel : public DiffusionModel {
516517
}
517518
};
518519

520+
struct ErnieImageModel : public DiffusionModel {
521+
std::string prefix;
522+
ErnieImage::ErnieImageRunner ernie_image;
523+
524+
ErnieImageModel(ggml_backend_t backend,
525+
bool offload_params_to_cpu,
526+
const String2TensorStorage& tensor_storage_map = {},
527+
const std::string prefix = "model.diffusion_model")
528+
: prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
529+
}
530+
531+
std::string get_desc() override {
532+
return ernie_image.get_desc();
533+
}
534+
535+
void alloc_params_buffer() override {
536+
ernie_image.alloc_params_buffer();
537+
}
538+
539+
void free_params_buffer() override {
540+
ernie_image.free_params_buffer();
541+
}
542+
543+
void free_compute_buffer() override {
544+
ernie_image.free_compute_buffer();
545+
}
546+
547+
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
548+
ernie_image.get_param_tensors(tensors, prefix);
549+
}
550+
551+
size_t get_params_buffer_size() override {
552+
return ernie_image.get_params_buffer_size();
553+
}
554+
555+
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
556+
ernie_image.set_weight_adapter(adapter);
557+
}
558+
559+
int64_t get_adm_in_channels() override {
560+
return 768;
561+
}
562+
563+
void set_flash_attention_enabled(bool enabled) {
564+
ernie_image.set_flash_attention_enabled(enabled);
565+
}
566+
567+
void set_circular_axes(bool circular_x, bool circular_y) override {
568+
ernie_image.set_circular_axes(circular_x, circular_y);
569+
}
570+
571+
sd::Tensor<float> compute(int n_threads,
572+
const DiffusionParams& diffusion_params) override {
573+
GGML_ASSERT(diffusion_params.x != nullptr);
574+
GGML_ASSERT(diffusion_params.timesteps != nullptr);
575+
return ernie_image.compute(n_threads,
576+
*diffusion_params.x,
577+
*diffusion_params.timesteps,
578+
tensor_or_empty(diffusion_params.context));
579+
}
580+
};
581+
519582
#endif

0 commit comments

Comments
 (0)