Skip to content

Commit acc3bf1

Browse files
authored
refactor: optimize the VAE architecture (#1345)
1 parent 83eabd7 commit acc3bf1

8 files changed

Lines changed: 1415 additions & 1337 deletions

File tree

src/auto_encoder_kl.hpp

Lines changed: 930 additions & 0 deletions
Large diffs are not rendered by default.

src/ggml_extend.hpp

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -377,6 +377,12 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
377377
ggml_free(ctx);
378378
}
379379

380+
// Duplicate-and-copy helper: creates a new tensor in `ctx` with
// ggml_dup_tensor(ctx, src), copies `src` into it with copy_ggml_tensor(dst, src),
// and returns the new tensor. `src` itself is only passed to those two calls here.
// NOTE(review): presumably the duplicate has the same type/shape as `src` and its
// own data buffer inside `ctx` (i.e. `ctx` must not be no_alloc) — confirm against
// ggml_dup_tensor in ggml.h.
__STATIC_INLINE__ ggml_tensor* ggml_ext_dup_and_cpy_tensor(ggml_context* ctx, ggml_tensor* src) {
381+
ggml_tensor* dup = ggml_dup_tensor(ctx, src);
382+
copy_ggml_tensor(dup, src);
383+
return dup;
384+
}
385+
380386
// Logistic sigmoid: returns 1 / (1 + e^(-x)), evaluated in single precision via expf.
__STATIC_INLINE__ float sigmoid(float x) {
381387
return 1 / (1.0f + expf(-x));
382388
}
@@ -637,7 +643,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context
637643
}
638644

639645
// convert values from [0, 1] to [-1, 1]
640-
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
646+
__STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) {
641647
int64_t nelements = ggml_nelements(src);
642648
float* data = (float*)src->data;
643649
for (int i = 0; i < nelements; i++) {
@@ -647,7 +653,7 @@ __STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
647653
}
648654

649655
// convert values from [-1, 1] to [0, 1]
650-
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
656+
__STATIC_INLINE__ void scale_to_0_1(struct ggml_tensor* src) {
651657
int64_t nelements = ggml_nelements(src);
652658
float* data = (float*)src->data;
653659
for (int i = 0; i < nelements; i++) {
@@ -834,7 +840,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
834840
const float tile_overlap_factor,
835841
const bool circular_x,
836842
const bool circular_y,
837-
on_tile_process on_processing) {
843+
on_tile_process on_processing,
844+
bool slient = false) {
838845
output = ggml_set_f32(output, 0);
839846

840847
int input_width = (int)input->ne[0];
@@ -864,8 +871,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
864871
float tile_overlap_factor_y;
865872
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
866873

867-
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
868-
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
874+
if (!slient) {
875+
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
876+
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
877+
}
869878

870879
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
871880
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
@@ -896,7 +905,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
896905
params.mem_buffer = nullptr;
897906
params.no_alloc = false;
898907

899-
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
908+
if (!slient) {
909+
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
910+
}
900911

901912
// draft context
902913
struct ggml_context* tiles_ctx = ggml_init(params);
@@ -909,8 +920,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
909920
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
910921
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
911922
int num_tiles = num_tiles_x * num_tiles_y;
912-
LOG_DEBUG("processing %i tiles", num_tiles);
913-
pretty_progress(0, num_tiles, 0.0f);
923+
if (!slient) {
924+
LOG_DEBUG("processing %i tiles", num_tiles);
925+
pretty_progress(0, num_tiles, 0.0f);
926+
}
914927
int tile_count = 1;
915928
bool last_y = false, last_x = false;
916929
float last_time = 0.0f;
@@ -960,8 +973,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
960973
}
961974
last_x = false;
962975
}
963-
if (tile_count < num_tiles) {
964-
pretty_progress(num_tiles, num_tiles, last_time);
976+
if (!slient) {
977+
if (tile_count < num_tiles) {
978+
pretty_progress(num_tiles, num_tiles, last_time);
979+
}
965980
}
966981
ggml_free(tiles_ctx);
967982
}

src/model.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1104,10 +1104,12 @@ SDVersion ModelLoader::get_sd_version() {
11041104
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
11051105
has_middle_block_1 = true;
11061106
}
1107-
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
1107+
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos ||
1108+
tensor_storage.name.find("unet.up_blocks.1.attentions.0.transformer_blocks.1") != std::string::npos) {
11081109
has_output_block_311 = true;
11091110
}
1110-
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
1111+
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos ||
1112+
tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) {
11111113
has_output_block_71 = true;
11121114
}
11131115
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||

src/name_conversion.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1120,7 +1120,11 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
11201120
for (const auto& prefix : first_stage_model_prefix_vec) {
11211121
if (starts_with(name, prefix)) {
11221122
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
1123-
name = prefix + name;
1123+
if (version == VERSION_SDXS) {
1124+
name = "tae." + name;
1125+
} else {
1126+
name = prefix + name;
1127+
}
11241128
break;
11251129
}
11261130
}

0 commit comments

Comments
 (0)