Skip to content

Commit 99c1de3

Browse files
wbruna and leejet authored
feat: ancestral sampler implementations for flow models (#1374)
* feat: add support for the eta parameter to ancestral samplers
* feat: Euler Ancestral sampler implementation for flow models
* refine flow ancestral sampling and normalize eta defaults

---------

Co-authored-by: leejet <leejet714@gmail.com>
1 parent 09b12d5 commit 99c1de3

5 files changed

Lines changed: 117 additions & 17 deletions

File tree

examples/cli/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ Generation Options:
114114
medium
115115
--skip-layer-start <float> SLG enabling point (default: 0.01)
116116
--skip-layer-end <float> SLG disabling point (default: 0.2)
117-
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
117+
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
118118
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
119119
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
120120
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
121121
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
122122
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
123123
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
124124
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
125-
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
125+
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
126126
--strength <float> strength for noising/unnoising (default: 0.75)
127127
--pm-style-strength <float>
128128
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

examples/common/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ struct SDGenerationParams {
11311131
&sample_params.guidance.slg.layer_end},
11321132
{"",
11331133
"--eta",
1134-
"eta in DDIM, only for DDIM and TCD (default: 0)",
1134+
"noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
11351135
&sample_params.eta},
11361136
{"",
11371137
"--flow-shift",
@@ -1163,7 +1163,7 @@ struct SDGenerationParams {
11631163
&high_noise_sample_params.guidance.slg.layer_end},
11641164
{"",
11651165
"--high-noise-eta",
1166-
"(high noise) eta in DDIM, only for DDIM and TCD (default: 0)",
1166+
"(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
11671167
&high_noise_sample_params.eta},
11681168
{"",
11691169
"--strength",

examples/server/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,15 @@ Default Generation Options:
189189
medium
190190
--skip-layer-start <float> SLG enabling point (default: 0.01)
191191
--skip-layer-end <float> SLG disabling point (default: 0.2)
192-
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
192+
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
193193
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
194194
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
195195
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
196196
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
197197
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
198198
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
199199
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
200-
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
200+
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
201201
--strength <float> strength for noising/unnoising (default: 0.75)
202202
--pm-style-strength <float>
203203
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

src/denoiser.hpp

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,43 @@ static std::pair<float, float> get_ancestral_step(float sigma_from,
786786
return {sigma_down, sigma_up};
787787
}
788788

789+
// Computes the parameters of one ancestral (noise re-injecting) step for a flow
// sampler moving from sigma_from to sigma_to.
//
// Returns {sigma_down, sigma_up, alpha_scale}:
//   sigma_down  - sigma to step down to before noise is added; clamped to
//                 [0, sigma_to].
//   sigma_up    - scale of the random noise added after the step.
//   alpha_scale - factor sample_euler_flow multiplies x by before adding the
//                 noise.
//
// eta controls how much noise is re-injected. For eta <= 0, or when either
// sigma is <= 0, the result is {sigma_to, 0, 1}. Values of eta above 1 are
// clamped to 1.
static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
790+
float sigma_to,
791+
float eta = 1.0f) {
792+
float sigma_down = sigma_to;
793+
float sigma_up = 0.0f;
794+
float alpha_scale = 1.0f;
795+
796+
// Nothing to inject: return the defaults set above (step straight to sigma_to).
if (eta <= 0.0f || sigma_from <= 0.0f || sigma_to <= 0.0f) {
797+
return {sigma_down, sigma_up, alpha_scale};
798+
}
799+
800+
// Flow Euler ancestral sampling becomes numerically unstable for eta > 1, so
801+
// clamp to the valid maximum-noise regime instead of letting NaNs propagate.
802+
eta = std::min(eta, 1.0f);
803+
804+
float sigma_ratio = sigma_to / sigma_from;
805+
// Linear in eta: eta = 0 gives sigma_to, eta = 1 gives sigma_to * sigma_to / sigma_from.
sigma_down = sigma_to * (1.0f + (sigma_ratio - 1.0f) * eta);
806+
// Keep sigma_down within [0, sigma_to].
sigma_down = std::max(0.0f, std::min(sigma_to, sigma_down));
807+
808+
float denom = 1.0f - sigma_down;
809+
// Avoid dividing by a zero or negative denominator; fall back to a step
// without noise.
if (denom <= 0.0f) {
810+
return {sigma_to, sigma_up, alpha_scale};
811+
}
812+
813+
alpha_scale = (1.0f - sigma_to) / denom;
814+
815+
// sigma_up = sigma_to * sqrt(1 - term^2). term is clamped to [-1, 1] and the
// sqrt argument to >= 0, so the square root never receives a negative value.
float term = (sigma_down / sigma_to) * alpha_scale;
816+
term = std::max(-1.0f, std::min(1.0f, term));
817+
sigma_up = sigma_to * std::sqrt(std::max(1.0f - term * term, 0.0f));
818+
return {sigma_down, sigma_up, alpha_scale};
819+
}
820+
789821
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
790822
sd::Tensor<float> x,
791823
const std::vector<float>& sigmas,
792-
std::shared_ptr<RNG> rng) {
824+
std::shared_ptr<RNG> rng,
825+
float eta) {
793826
int steps = static_cast<int>(sigmas.size()) - 1;
794827
for (int i = 0; i < steps; i++) {
795828
float sigma = sigmas[i];
@@ -799,7 +832,7 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
799832
}
800833
sd::Tensor<float> denoised = std::move(denoised_opt);
801834
sd::Tensor<float> d = (x - denoised) / sigma;
802-
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
835+
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
803836
x += d * (sigma_down - sigmas[i]);
804837
if (sigmas[i + 1] > 0) {
805838
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
@@ -808,6 +841,30 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
808841
return x;
809842
}
810843

844+
// Euler ancestral sampling loop used for flow denoisers (sample_k_diffusion
// selects it for EULER_A_SAMPLE_METHOD when is_flow_denoiser is true).
//
// For each consecutive pair (sigmas[i], sigmas[i + 1]) it:
//   1. calls `model` on the current x to obtain a denoised estimate,
//   2. blends x with that estimate using sigma_down / sigma as the weight,
//   3. if sigma_up > 0, scales x by alpha_scale and adds noise drawn via
//      `rng`, scaled by sigma_up.
//
// `eta` is forwarded unchanged to get_ancestral_step_flow.
// Returns the final x, or `{}` as soon as `model` yields an empty result.
static sd::Tensor<float> sample_euler_flow(denoise_cb_t model,
845+
sd::Tensor<float> x,
846+
const std::vector<float>& sigmas,
847+
std::shared_ptr<RNG> rng,
848+
float eta) {
849+
// One step per adjacent pair of sigmas.
int steps = static_cast<int>(sigmas.size()) - 1;
850+
for (int i = 0; i < steps; i++) {
851+
float sigma = sigmas[i];
852+
// Third argument is i + 1 - presumably a 1-based step index; confirm
// against the denoise_cb_t definition.
auto denoised_opt = model(x, sigma, i + 1);
853+
if (denoised_opt.empty()) {
854+
return {};
855+
}
856+
sd::Tensor<float> denoised = std::move(denoised_opt);
857+
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigmas[i + 1], eta);
858+
// Weighted blend of the current sample and the denoised estimate;
// sigma_down == 0 leaves x equal to denoised.
float sigma_ratio = sigma_down / sigma;
859+
x = sigma_ratio * x + (1.0f - sigma_ratio) * denoised;
860+
861+
// Re-inject noise only when the step helper asked for some.
if (sigma_up > 0.0f) {
862+
x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
863+
}
864+
}
865+
return x;
866+
}
867+
811868
static sd::Tensor<float> sample_euler(denoise_cb_t model,
812869
sd::Tensor<float> x,
813870
const std::vector<float>& sigmas) {
@@ -885,7 +942,8 @@ static sd::Tensor<float> sample_dpm2(denoise_cb_t model,
885942
static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
886943
sd::Tensor<float> x,
887944
const std::vector<float>& sigmas,
888-
std::shared_ptr<RNG> rng) {
945+
std::shared_ptr<RNG> rng,
946+
float eta) {
889947
auto t_fn = [](float sigma) -> float { return -log(sigma); };
890948
auto sigma_fn = [](float t) -> float { return exp(-t); };
891949

@@ -896,7 +954,7 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
896954
return {};
897955
}
898956
sd::Tensor<float> denoised = std::move(denoised_opt);
899-
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
957+
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
900958

901959
if (sigma_down == 0) {
902960
x = denoised;
@@ -1368,18 +1426,22 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
13681426
sd::Tensor<float> x,
13691427
std::vector<float> sigmas,
13701428
std::shared_ptr<RNG> rng,
1371-
float eta) {
1429+
float eta,
1430+
bool is_flow_denoiser) {
13721431
switch (method) {
13731432
case EULER_A_SAMPLE_METHOD:
1374-
return sample_euler_ancestral(model, std::move(x), sigmas, rng);
1433+
if (is_flow_denoiser)
1434+
return sample_euler_flow(model, std::move(x), sigmas, rng, eta);
1435+
else
1436+
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
13751437
case EULER_SAMPLE_METHOD:
13761438
return sample_euler(model, std::move(x), sigmas);
13771439
case HEUN_SAMPLE_METHOD:
13781440
return sample_heun(model, std::move(x), sigmas);
13791441
case DPM2_SAMPLE_METHOD:
13801442
return sample_dpm2(model, std::move(x), sigmas);
13811443
case DPMPP2S_A_SAMPLE_METHOD:
1382-
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng);
1444+
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta);
13831445
case DPMPP2M_SAMPLE_METHOD:
13841446
return sample_dpmpp_2m(model, std::move(x), sigmas);
13851447
case DPMPP2Mv2_SAMPLE_METHOD:

src/stable-diffusion.cpp

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,7 @@ class StableDiffusionGGML {
15931593
float eta,
15941594
int shifted_timestep,
15951595
sample_method_t method,
1596+
bool is_flow_denoiser,
15961597
const std::vector<float>& sigmas,
15971598
int start_merge_step,
15981599
const std::vector<sd::Tensor<float>>& ref_latents,
@@ -1791,7 +1792,7 @@ class StableDiffusionGGML {
17911792
return denoised;
17921793
};
17931794

1794-
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta);
1795+
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta, is_flow_denoiser);
17951796
if (x0_opt.empty()) {
17961797
LOG_ERROR("Diffusion model sampling failed");
17971798
if (control_net) {
@@ -1909,6 +1910,11 @@ class StableDiffusionGGML {
19091910
flow_denoiser->set_shift(flow_shift);
19101911
}
19111912
}
1913+
1914+
// Returns true when `denoiser` can be dynamically cast to
// DiscreteFlowDenoiser, false otherwise (including when the cast yields null).
bool is_flow_denoiser() {
1915+
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
1916+
// !! converts the (possibly null) shared_ptr to bool.
return !!flow_denoiser;
1917+
}
19121918
};
19131919

19141920
/*================================================= SD API ==================================================*/
@@ -2225,6 +2231,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
22252231
sample_params->scheduler = SCHEDULER_COUNT;
22262232
sample_params->sample_method = SAMPLE_METHOD_COUNT;
22272233
sample_params->sample_steps = 20;
2234+
sample_params->eta = INFINITY;
22282235
sample_params->custom_sigmas = nullptr;
22292236
sample_params->custom_sigmas_count = 0;
22302237
sample_params->flow_shift = INFINITY;
@@ -2438,6 +2445,26 @@ static scheduler_t resolve_scheduler(sd_ctx_t* sd_ctx,
24382445
return scheduler;
24392446
}
24402447

2448+
// Resolves the effective eta for `sample_method`.
//
// An eta equal to INFINITY is treated as "not set" (sd_sample_params_init
// stores INFINITY in sample_params->eta) and is replaced by a per-method
// default:
//   - 1.0 for EULER_A_SAMPLE_METHOD and DPMPP2S_A_SAMPLE_METHOD,
//   - 0.0 for DDIM_TRAILING, TCD, RES_MULTISTEP, RES_2S and every other method.
// Any other eta value is returned unchanged.
//
// `sd_ctx` is not read by this function.
static float resolve_eta(sd_ctx_t* sd_ctx,
2449+
float eta,
2450+
enum sample_method_t sample_method) {
2451+
if (eta == INFINITY) {
2452+
switch (sample_method) {
2453+
case DDIM_TRAILING_SAMPLE_METHOD:
2454+
case TCD_SAMPLE_METHOD:
2455+
case RES_MULTISTEP_SAMPLE_METHOD:
2456+
case RES_2S_SAMPLE_METHOD:
2457+
return 0.0f;
2458+
case EULER_A_SAMPLE_METHOD:
2459+
case DPMPP2S_A_SAMPLE_METHOD:
2460+
return 1.0f;
2461+
// Methods not listed above fall through to the 0.0f return below.
default:;
2462+
}
2463+
return 0.0f;
2464+
}
2465+
return eta;
2466+
}
2467+
24412468
struct GenerationRequest {
24422469
std::string prompt;
24432470
std::string negative_prompt;
@@ -2586,6 +2613,8 @@ struct GenerationRequest {
25862613
struct SamplePlan {
25872614
enum sample_method_t sample_method = SAMPLE_METHOD_COUNT;
25882615
enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT;
2616+
float eta = 0.f;
2617+
float high_noise_eta = 0.f;
25892618
int sample_steps = 0;
25902619
int high_noise_sample_steps = 0;
25912620
int total_steps = 0;
@@ -2597,6 +2626,7 @@ struct SamplePlan {
25972626
const sd_img_gen_params_t* sd_img_gen_params,
25982627
const GenerationRequest& request) {
25992628
sample_method = sd_img_gen_params->sample_params.sample_method;
2629+
eta = sd_img_gen_params->sample_params.eta;
26002630
sample_steps = sd_img_gen_params->sample_params.sample_steps;
26012631
resolve(sd_ctx, &request, &sd_img_gen_params->sample_params);
26022632
}
@@ -2605,10 +2635,12 @@ struct SamplePlan {
26052635
const sd_vid_gen_params_t* sd_vid_gen_params,
26062636
const GenerationRequest& request) {
26072637
sample_method = sd_vid_gen_params->sample_params.sample_method;
2638+
eta = sd_vid_gen_params->sample_params.eta;
26082639
sample_steps = sd_vid_gen_params->sample_params.sample_steps;
26092640
if (sd_ctx->sd->high_noise_diffusion_model) {
26102641
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
26112642
high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
2643+
high_noise_eta = sd_vid_gen_params->high_noise_sample_params.eta;
26122644
}
26132645
moe_boundary = sd_vid_gen_params->moe_boundary;
26142646
resolve(sd_ctx, &request, &sd_vid_gen_params->sample_params);
@@ -2644,6 +2676,8 @@ struct SamplePlan {
26442676
sd_ctx->sd->version);
26452677
}
26462678

2679+
eta = resolve_eta(sd_ctx, eta, sample_method);
2680+
26472681
if (high_noise_sample_steps < 0) {
26482682
for (size_t i = 0; i < sigmas.size(); ++i) {
26492683
if (sigmas[i] < moe_boundary) {
@@ -2658,6 +2692,7 @@ struct SamplePlan {
26582692
if (high_noise_sample_steps > 0) {
26592693
high_noise_sample_method = resolve_sample_method(sd_ctx,
26602694
high_noise_sample_method);
2695+
high_noise_eta = resolve_eta(sd_ctx, high_noise_eta, high_noise_sample_method);
26612696
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
26622697
}
26632698

@@ -3123,9 +3158,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
31233158
latents.control_image,
31243159
request.control_strength,
31253160
request.guidance,
3126-
request.eta,
3161+
plan.eta,
31273162
request.shifted_timestep,
31283163
plan.sample_method,
3164+
sd_ctx->sd->is_flow_denoiser(),
31293165
plan.sigmas,
31303166
plan.start_merge_step,
31313167
latents.ref_latents,
@@ -3482,9 +3518,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34823518
sd::Tensor<float>(),
34833519
0.f,
34843520
request.high_noise_guidance,
3485-
sd_vid_gen_params->high_noise_sample_params.eta,
3521+
plan.high_noise_eta,
34863522
request.shifted_timestep,
34873523
plan.high_noise_sample_method,
3524+
sd_ctx->sd->is_flow_denoiser(),
34883525
high_noise_sigmas,
34893526
-1,
34903527
std::vector<sd::Tensor<float>>{},
@@ -3523,9 +3560,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35233560
sd::Tensor<float>(),
35243561
0.f,
35253562
sd_vid_gen_params->sample_params.guidance,
3526-
sd_vid_gen_params->sample_params.eta,
3563+
plan.eta,
35273564
sd_vid_gen_params->sample_params.shifted_timestep,
35283565
plan.sample_method,
3566+
sd_ctx->sd->is_flow_denoiser(),
35293567
plan.sigmas,
35303568
-1,
35313569
std::vector<sd::Tensor<float>>{},

0 commit comments

Comments
 (0)