diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp index de6ec5d14..392ca5e69 100644 --- a/csrc/engine/compiler/paged_compiler.cpp +++ b/csrc/engine/compiler/paged_compiler.cpp @@ -109,7 +109,11 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. + auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index dcd7f7143..637b3e791 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -56,7 +56,11 @@ StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled( graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. + auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } } else { diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 2e34c6cca..68c1d8ad7 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -14,7 +14,8 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - std::optional kv_cache_dtype) // Changed parameter + std::optional kv_cache_dtype, // Changed parameter + size_t max_num_batched_tokens) : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -22,7 +23,7 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str); - auto infinilm_config = std::make_shared(attention_backend, this->model_config_); + auto infinilm_config = std::make_shared(attention_backend, this->model_config_, max_num_batched_tokens); // Only support offline int8 kv cache quantization in this version if (kv_cache_dtype.has_value()) { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 4bafffa2c..30174bddc 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -27,7 +27,8 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - std::optional kv_cache_dtype = std::nullopt); + std::optional kv_cache_dtype = std::nullopt, + size_t max_num_batched_tokens = 2048); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp index a7ab6f862..619ed692f 100644 --- a/csrc/global_state/forward_context.hpp +++ b/csrc/global_state/forward_context.hpp @@ -1,6 +1,7 @@ #pragma once #include "../models/infinilm_model.hpp" +#include namespace infinilm::global_state { @@ -43,6 +44,9 @@ struct AttentionMetadata { struct ForwardContext { AttentionMetadata attn_metadata; std::vector kv_cache_vec; + + // preallocated workspace for some modules + std::unordered_map preallocated_workspace; }; void initialize_forward_context(ForwardContext &forward_context); diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp index 9b80706ca..84f245d4c 100644 --- a/csrc/global_state/infinilm_config.hpp +++ b/csrc/global_state/infinilm_config.hpp @@ -14,13 +14,24 @@ struct InfinilmConfig { public: InfinilmConfig() = default; InfinilmConfig(const infinilm::backends::AttentionBackend &backend, - const std::shared_ptr &model_config) + const std::shared_ptr &model_config, + size_t max_num_batched_tokens) : attention_backend(backend), - model_config(model_config) {} + model_config(model_config), + max_num_batched_tokens(max_num_batched_tokens) { + + if (max_num_batched_tokens > 0) { + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings); + enable_preallocated_workspace = true; + } + } public: infinilm::backends::AttentionBackend attention_backend; std::shared_ptr model_config; + size_t max_num_batched_tokens = 0; + bool enable_preallocated_workspace = false; }; /** diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp index 1b87f6fbc..cb0bb0b32 100644 --- a/csrc/layers/attention/attention.cpp +++ b/csrc/layers/attention/attention.cpp @@ -1,17 +1,21 @@ #include "attention.hpp" +#include "../../global_state/global_state.hpp" #include "../../utils.hpp" #include "../rotary_embedding/rotary_embedding.hpp" +#include +#include namespace infinilm::layers::attention { Attention::Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ -31,18 +35,24 @@ Attention::Attention(std::shared_ptr model_config qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + use_output_bias, dtype_, device_, tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + enable_preallocated_workspace_ = infinilm::global_state::get_infinilm_config().enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } } infinicore::Tensor Attention::forward(const infinicore::Tensor &positions, @@ -62,7 +72,15 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position size_t seq_len = shape[1]; // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q; + infinicore::Tensor k; + infinicore::Tensor v; + if (enable_preallocated_workspace_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); @@ -89,9 +107,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position // 5. Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto output = o_proj_->forward(attn_output); - return output; + // 6. Project output + if (enable_preallocated_workspace_) { + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -106,7 +128,15 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ ASSERT_EQ(batch_size, 1); // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q; + infinicore::Tensor k; + infinicore::Tensor v; + if (enable_preallocated_workspace_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -133,8 +163,38 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output - auto output = o_proj_->forward(attn_output); - return output; + if (enable_preallocated_workspace_) { + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); +} + +void Attention::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { + auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[attention_cache_key] = attention_buffer; + } + + auto attention_buffer = preallocated_workspace.at(attention_cache_key); + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); } void init_kv_cache_quant_params(std::function register_fn, diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp index 31f0d1fa4..6bcef9da0 100644 --- a/csrc/layers/attention/attention.hpp +++ b/csrc/layers/attention/attention.hpp @@ -5,6 +5,8 @@ #include "../../global_state/global_state.hpp" #include "../linear/linear.hpp" #include "backends/attention_layer.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/tensor.hpp" @@ -37,6 +39,8 @@ class Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _initialize_preallocated_workspace(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -49,13 +53,24 @@ class Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + +private: + bool enable_preallocated_workspace_{false}; + + size_t rank_qkv_output_size_; + + // preallocated workspace for Attention + infinicore::Tensor max_qkv_output_; + infinicore::Tensor max_o_output_; }; void init_kv_cache_quant_params(std::function register_fn, - const infinicore::Device &device, - infinicore::nn::Parameter &kv_cache_k_scale, - infinicore::nn::Parameter &kv_cache_v_scale); + const infinicore::Device &device, + infinicore::nn::Parameter &kv_cache_k_scale, + infinicore::nn::Parameter &kv_cache_v_scale); } // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/attention_layer.cpp b/csrc/layers/attention/backends/attention_layer.cpp index fcaefa292..e5e39c10f 100644 --- a/csrc/layers/attention/backends/attention_layer.cpp +++ b/csrc/layers/attention/backends/attention_layer.cpp @@ -9,16 +9,17 @@ AttentionLayer::AttentionLayer(size_t num_heads, size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - ::infinilm::backends::AttentionBackend attn_backend) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) { + ::infinilm::backends::AttentionBackend attn_backend, + const infinicore::Device &device) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) { switch (attn_backend) { case ::infinilm::backends::AttentionBackend::STATIC_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::PAGED_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::FLASH_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; default: throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend"); diff --git a/csrc/layers/attention/backends/attention_layer.hpp b/csrc/layers/attention/backends/attention_layer.hpp index 874110629..d83a79796 100644 --- a/csrc/layers/attention/backends/attention_layer.hpp +++ b/csrc/layers/attention/backends/attention_layer.hpp @@ -31,7 +31,8 @@ class AttentionLayer { size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - ::infinilm::backends::AttentionBackend attention_backend); + ::infinilm::backends::AttentionBackend attention_backend, + const infinicore::Device &device); infinicore::Tensor forward(infinicore::Tensor &query, infinicore::Tensor &key, diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp index ec7e37722..3437cc505 100644 --- a/csrc/layers/attention/backends/flash_attn.cpp +++ b/csrc/layers/attention/backends/flash_attn.cpp @@ -1,9 +1,11 @@ #include "flash_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" #include "infinicore/ops.hpp" #include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" +#include namespace infinilm::layers::attention::backends { @@ -11,19 +13,29 @@ FlashAttentionImpl::FlashAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) { + head_dim_(head_size), + device_(device) { const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); if (!infinilm_config.model_config) { throw std::runtime_error("infinilm::layers::attention::backends::FlashAttentionImpl: model_config is null"); } - max_position_embeddings_ = infinilm_config.model_config->get("max_position_embeddings"); + + const auto &model_config = infinilm_config.model_config; + dtype_ = model_config->get_dtype(); + max_position_embeddings_ = model_config->get("max_position_embeddings"); + + enable_preallocated_workspace_ = infinilm_config.enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } } infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, @@ -48,8 +60,13 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output; if (is_prefill) { + if (enable_preallocated_workspace_) { + attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } infinicore::op::mha_varlen_( attn_output, query, @@ -99,4 +116,27 @@ std::tuple FlashAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } +void FlashAttentionImpl::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("FlashAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { + auto flash_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); + preallocated_workspace[cache_key] = flash_attention_impl_buffer; + } + + auto flash_attention_impl_buffer = preallocated_workspace.at(cache_key); + const auto buffer_shape = flash_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + + max_attn_output_ = flash_attention_impl_buffer; +} } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp index 93f61e8ba..47fc76216 100644 --- a/csrc/layers/attention/backends/flash_attn.hpp +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -1,6 +1,8 @@ #pragma once #include "../../../global_state/global_state.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -16,7 +18,8 @@ class FlashAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); /** * @brief Forward pass with FlashAttention. @@ -43,6 +46,10 @@ class FlashAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _initialize_preallocated_workspace(); + + bool enable_preallocated_workspace_{false}; + size_t num_heads_; size_t head_size_; float scale_; @@ -50,5 +57,11 @@ class FlashAttentionImpl { size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size size_t max_position_embeddings_; + infinicore::Device device_; + infinicore::DataType dtype_; + + // preallocated workspace for FlashAttentionImpl + infinicore::Tensor max_attn_output_; }; + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp index a0ad70afe..4e5e3f0ea 100644 --- a/csrc/layers/attention/backends/paged_attn.cpp +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -1,21 +1,40 @@ #include "paged_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" +#include "attention_layer.hpp" #include "infinicore/ops.hpp" +#include + namespace infinilm::layers::attention::backends { PagedAttentionImpl::PagedAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) {} + head_dim_(head_size), + device_(device) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw std::runtime_error("infinilm::layers::attention::backends::PagedAttentionImpl: model_config is null"); + } + + dtype_ = infinilm_config.model_config->get_dtype(); + + enable_preallocated_workspace_ = infinilm_config.enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } +} infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, const infinicore::Tensor &query, @@ -37,7 +56,12 @@ infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output; + if (enable_preallocated_workspace_) { + attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } if (is_prefill) { infinicore::op::paged_attention_prefill_( attn_output, @@ -80,4 +104,29 @@ std::tuple PagedAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } + +void PagedAttentionImpl::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("PagedAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { + auto paged_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); + preallocated_workspace[cache_key] = paged_attention_impl_buffer; + } + + auto paged_attention_impl_buffer = preallocated_workspace.at(cache_key); + const auto buffer_shape = paged_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + + max_attn_output_ = paged_attention_impl_buffer; +} + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp b/csrc/layers/attention/backends/paged_attn.hpp index 4f53ea573..7433992bd 100644 --- a/csrc/layers/attention/backends/paged_attn.hpp +++ b/csrc/layers/attention/backends/paged_attn.hpp @@ -1,6 +1,8 @@ #pragma once #include "../../../global_state/global_state.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -16,7 +18,8 @@ class PagedAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); /** * @brief Forward pass with PagedAttention. @@ -43,11 +46,21 @@ class PagedAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _initialize_preallocated_workspace(); + + bool enable_preallocated_workspace_{false}; + size_t num_heads_; size_t head_size_; float scale_; size_t num_kv_heads_; size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size + infinicore::Device device_; + infinicore::DataType dtype_; + + // preallocated workspace for PagedAttentionImpl + infinicore::Tensor max_attn_output_; }; + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index 2d1b7e11a..668f4c218 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -11,7 +11,8 @@ StaticAttentionImpl::StaticAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device & /*device*/) : num_heads_(num_heads), head_size_(head_size), scale_(scale), diff --git a/csrc/layers/attention/backends/static_attn.hpp b/csrc/layers/attention/backends/static_attn.hpp index 849d87928..00af4391e 100644 --- a/csrc/layers/attention/backends/static_attn.hpp +++ b/csrc/layers/attention/backends/static_attn.hpp @@ -18,7 +18,8 @@ class StaticAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); infinicore::Tensor forward(const AttentionLayer &layer, infinicore::Tensor &q_reshaped, // query diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp index eb4f2b47f..c2a7f8c66 100644 --- a/csrc/layers/causal_lm_templates/text_causal_lm.hpp +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -1,8 +1,12 @@ #pragma once +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" #include "../linear/linear.hpp" #include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" +#include namespace infinilm::layers::causal_lm_templates { @@ -28,15 +32,21 @@ class TextCausalLM : public InfinilmModel { * @param device: Device to create tensors on */ TextCausalLM(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { model_config_ = model_config; size_t hidden_size = model_config->get("hidden_size"); - size_t vocab_size = model_config->get("vocab_size"); - const auto &dtype{model_config->get_dtype()}; + vocab_size_ = model_config->get("vocab_size"); model_ = this->register_module("model", model_config, device); - lm_head_ = this->register_module("lm_head", hidden_size, vocab_size, false, dtype, device); + lm_head_ = this->register_module("lm_head", hidden_size, vocab_size_, false, dtype_, device_); + + enable_preallocated_workspace_ = infinilm::global_state::get_infinilm_config().enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } } /** @@ -44,7 +54,18 @@ class TextCausalLM : public InfinilmModel { */ Output forward(const Input &input) const override { auto hidden_states = model_->forward(input); - auto logits = lm_head_->forward(hidden_states); + infinicore::Tensor logits; + + if (enable_preallocated_workspace_) { + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + logits = max_logits_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, vocab_size_}); + lm_head_->forward_(logits, hidden_states); + } else { + logits = lm_head_->forward(hidden_states); + } + return {logits}; } @@ -55,8 +76,41 @@ class TextCausalLM : public InfinilmModel { Model &model() { return *model_; } protected: + size_t vocab_size_; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(Model, model); INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); + +private: + void _initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string text_causal_lm_cache_key = std::string("TextCausalLM_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_vocab_size_" + + std::to_string(vocab_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(text_causal_lm_cache_key) == preallocated_workspace.end()) { + auto logits_buffer = infinicore::Tensor::empty({max_num_batched_tokens, vocab_size_}, dtype_, device_); + preallocated_workspace[text_causal_lm_cache_key] = logits_buffer; + } + + auto logits_buffer = preallocated_workspace.at(text_causal_lm_cache_key); + const auto logits_buffer_shape = logits_buffer->shape(); + ASSERT(logits_buffer_shape[0] == max_num_batched_tokens && logits_buffer_shape[1] == vocab_size_); + + max_logits_ = logits_buffer; + } + + bool enable_preallocated_workspace_{false}; + + // preallocated workspace for TextCausalLM + infinicore::Tensor max_logits_; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_model.hpp b/csrc/layers/causal_lm_templates/text_model.hpp index 62a52798b..5eca0b029 100644 --- a/csrc/layers/causal_lm_templates/text_model.hpp +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -1,11 +1,16 @@ #pragma once #include "../../config/model_config.hpp" +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/embedding.hpp" #include "infinicore/nn/rmsnorm.hpp" #include "infinicore/tensor.hpp" #include +#include #include namespace infinilm::layers::causal_lm_templates { @@ -24,30 +29,46 @@ template class TextModel : public infinicore::nn::Module { public: TextModel(std::shared_ptr model_config, - const infinicore::Device &device) { - const auto &dtype{model_config->get_dtype()}; - size_t vocab_size = model_config->get("vocab_size"); - size_t hidden_size = model_config->get("hidden_size"); + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { + vocab_size_ = model_config->get("vocab_size"); + hidden_size_ = model_config->get("hidden_size"); size_t max_position_embeddings = model_config->get("max_position_embeddings"); size_t num_hidden_layers = model_config->get("num_hidden_layers"); double rope_theta = model_config->get("rope_theta"); double rms_norm_eps = model_config->get("rms_norm_eps"); - embed_tokens_ = this->register_module("embed_tokens", vocab_size, hidden_size, std::nullopt, dtype, device); + embed_tokens_ = this->register_module("embed_tokens", vocab_size_, hidden_size_, std::nullopt, dtype_, device_); layers_.reserve(num_hidden_layers); for (size_t i = 0; i < num_hidden_layers; ++i) { - layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device_)); } - norm_ = this->register_module("norm", hidden_size, rms_norm_eps, dtype, device); + norm_ = this->register_module("norm", hidden_size_, rms_norm_eps, dtype_, device_); + + enable_preallocated_workspace_ = infinilm::global_state::get_infinilm_config().enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } } infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const { auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); + // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] - auto hidden_states = embed_tokens_->forward(input_ids); + infinicore::Tensor hidden_states; + if (enable_preallocated_workspace_) { + const auto shape = input_ids->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + hidden_states = max_hidden_states_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + embed_tokens_->forward_(hidden_states, input_ids); + } else { + hidden_states = embed_tokens_->forward(input_ids); + } // 2. Process through all decoder layers size_t num_layers = layers_.size(); @@ -64,6 +85,7 @@ class TextModel : public infinicore::nn::Module { } infinicore::Tensor forward_naive(const infinilm::InfinilmModel::Input &input) const { + // Don't use preallocated workspace in forward_naive function. auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); auto hidden_states = embed_tokens_->forward(input_ids); @@ -78,6 +100,7 @@ class TextModel : public infinicore::nn::Module { infinicore::Tensor forward_embeds(const infinicore::Tensor &inputs_embeds, const infinicore::Tensor &position_ids) const { + // Don't use preallocated workspace in forward_embeds function. auto hidden_states = inputs_embeds; // Process through all decoder layers @@ -98,10 +121,41 @@ class TextModel : public infinicore::nn::Module { return embed_tokens_->forward(input_ids); } +private: + void _initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string text_model_cache_key = std::string("TextModel_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_hidden_size_" + std::to_string(hidden_size_) + "_dtype_" + infinicore::toString(dtype_) + "_device_" + device_.toString(); + + if (preallocated_workspace.find(text_model_cache_key) == preallocated_workspace.end()) { + auto text_model_buffer = infinicore::Tensor::empty({max_num_batched_tokens, hidden_size_}, dtype_, device_); + preallocated_workspace[text_model_cache_key] = text_model_buffer; + } + + auto text_model_buffer = preallocated_workspace.at(text_model_cache_key); + const auto text_model_buffer_shape = text_model_buffer->shape(); + ASSERT(text_model_buffer_shape[0] == max_num_batched_tokens && text_model_buffer_shape[1] == hidden_size_); + + max_hidden_states_ = text_model_buffer; + } + protected: + size_t vocab_size_; + size_t hidden_size_; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); INFINICORE_NN_MODULE_VEC(DecoderLayer, layers); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + +private: + bool enable_preallocated_workspace_{false}; + + // preallocated workspace for TextModel + infinicore::Tensor max_hidden_states_; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/linear/base_linear.cpp b/csrc/layers/linear/base_linear.cpp index eebc482c4..25b292600 100644 --- a/csrc/layers/linear/base_linear.cpp +++ b/csrc/layers/linear/base_linear.cpp @@ -43,10 +43,25 @@ infinicore::Tensor BaseLinear::compute_linear(infinicore::Tensor &input) const { return quantization_->forward(params, input, has_bias_, alpha_); } +void BaseLinear::compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const { + // Build params map from direct parameters only (not state_dict which uses a + // static local and is not thread-safe across RankWorker threads). + infinilm::quantization::ParamsMap params; + for (const auto &[name, param] : parameters_) { + params[name] = static_cast(param); + } + + quantization_->forward_(output, params, input, has_bias_, alpha_); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input) const { return compute_linear(input); } +void BaseLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + compute_linear_(output, input); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input, infinicore::Tensor &residual) const { auto output = compute_linear(input); infinicore::op::add_(output, output, residual); @@ -60,7 +75,9 @@ void BaseLinear::process_weights_after_loading() { } auto new_quant = quantization_->process_weights_after_loading(params, device_); - if (!new_quant) return; + if (!new_quant) { + return; + } for (auto &[name, param] : parameters_) { param = infinicore::nn::Parameter(); @@ -68,7 +85,9 @@ void BaseLinear::process_weights_after_loading() { for (const auto &[name, tensor] : params) { auto it = parameters_.find(name); - if (it == parameters_.end()) continue; + if (it == parameters_.end()) { + continue; + } it->second = infinicore::nn::Parameter(tensor); } @@ -79,43 +98,61 @@ void BaseLinear::process_weights_after_loading() { infinicore::Tensor BaseLinear::weight() const { auto it = parameters_.find("weight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qweight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::bias() const { auto it = parameters_.find("bias"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_scale() const { auto it = parameters_.find("weight_scale"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("scales"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_zeros() const { auto it = parameters_.find("weight_zeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qzeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::gidx() const { auto it = parameters_.find("g_idx"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::get_param(const std::string &name) const { auto it = parameters_.find(name); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } diff --git a/csrc/layers/linear/base_linear.hpp b/csrc/layers/linear/base_linear.hpp index a304f4573..52ac27c40 100644 --- a/csrc/layers/linear/base_linear.hpp +++ b/csrc/layers/linear/base_linear.hpp @@ -1,8 +1,8 @@ #pragma once -#include "infinicore/ops.hpp" #include "../quantization/quantization.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/ops.hpp" #include #include @@ -23,6 +23,8 @@ class BaseLinear : public infinicore::nn::Module { // Forward pass: output = input @ weight.T + bias infinicore::Tensor forward(infinicore::Tensor &input) const; + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + // Forward pass with residual connection infinicore::Tensor forward(infinicore::Tensor &input, infinicore::Tensor &residual) const; @@ -57,6 +59,7 @@ class BaseLinear : public infinicore::nn::Module { protected: infinicore::Tensor compute_linear(infinicore::Tensor &input) const; + void compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const; size_t in_features_; size_t out_features_; diff --git a/csrc/layers/linear/fused_linear.cpp b/csrc/layers/linear/fused_linear.cpp index 04d0ad316..6cc0c517b 100644 --- a/csrc/layers/linear/fused_linear.cpp +++ b/csrc/layers/linear/fused_linear.cpp @@ -31,14 +31,14 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size, const infinicore::Device &device, engine::distributed::RankInfo rank_info) : infinilm::nn::ColumnParallelLinear( - hidden_size, - calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info), - quantization, - (q_bias || k_bias || v_bias), - dtype, - device, - rank_info.tp_rank, - rank_info.tp_size), + hidden_size, + calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info), + quantization, + (q_bias || k_bias || v_bias), + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size), q_dim_(q_dim), k_dim_(k_dim), v_dim_(v_dim), @@ -70,6 +70,17 @@ QKVParallelLinear::forward_split(infinicore::Tensor &input) { return std::make_tuple(q_out, k_out, v_out); } +std::tuple +QKVParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + + auto q_out = output->narrow({{2, 0, q_out_size_}}); + auto k_out = output->narrow({{2, q_out_size_, k_out_size_}}); + auto v_out = output->narrow({{2, q_out_size_ + k_out_size_, v_out_size_}}); + + return std::make_tuple(q_out, k_out, v_out); +} + bool QKVParallelLinear::has_q_bias() const { return q_bias_; } bool QKVParallelLinear::has_k_bias() const { return k_bias_; } bool QKVParallelLinear::has_v_bias() const { return v_bias_; } @@ -134,6 +145,15 @@ std::tuple GateUpParallelLinear::forward return std::make_tuple(gate_output, up_output); } +std::tuple +GateUpParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + auto cols = output->shape()[2]; + auto gate_output = output->narrow({{2, 0, cols / 2}}); + auto up_output = output->narrow({{2, cols / 2, cols / 2}}); + return std::make_tuple(gate_output, up_output); +} + bool GateUpParallelLinear::has_gate_bias() const { return gate_bias_; } bool GateUpParallelLinear::has_up_bias() const { return up_bias_; } diff --git a/csrc/layers/linear/fused_linear.hpp b/csrc/layers/linear/fused_linear.hpp index 92ec3b909..35a700ce2 100644 --- a/csrc/layers/linear/fused_linear.hpp +++ b/csrc/layers/linear/fused_linear.hpp @@ -1,7 +1,7 @@ #pragma once #include "../../engine/distributed/communication_group.hpp" -#include "linear.hpp" #include "../quantization/quantization.hpp" +#include "linear.hpp" #include namespace infinilm::layers::linear { @@ -43,6 +43,9 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear { std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_q_bias() const; bool has_k_bias() const; bool has_v_bias() const; @@ -108,6 +111,9 @@ class GateUpParallelLinear : public infinilm::nn::ColumnParallelLinear { std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_gate_bias() const; bool has_up_bias() const; diff --git a/csrc/layers/linear/linear.cpp b/csrc/layers/linear/linear.cpp index 84982409f..f73abbb29 100644 --- a/csrc/layers/linear/linear.cpp +++ b/csrc/layers/linear/linear.cpp @@ -93,6 +93,14 @@ infinicore::Tensor RowParallelLinear::forward(infinicore::Tensor &input) const { return output; } +void RowParallelLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + BaseLinear::forward_(output, input); + + if ((tp_size_ > 1) && (communicator_ != nullptr)) { + infinicore::op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); + } +} + std::string RowParallelLinear::extra_repr() const { return "RowParallelLinear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/csrc/layers/linear/linear.hpp b/csrc/layers/linear/linear.hpp index 566cee77c..abae08f66 100644 --- a/csrc/layers/linear/linear.hpp +++ b/csrc/layers/linear/linear.hpp @@ -70,6 +70,9 @@ class RowParallelLinear : public BaseLinear { infinicclComm_t communicator = nullptr); infinicore::Tensor forward(infinicore::Tensor &input) const; + + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + std::string extra_repr() const; protected: diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp index f7604c505..c65b60f5f 100644 --- a/csrc/layers/mlp/mlp.cpp +++ b/csrc/layers/mlp/mlp.cpp @@ -1,13 +1,16 @@ #include "mlp.hpp" #include "../../global_state/global_state.hpp" +#include "../../utils.hpp" #include "infinicore/ops.hpp" +#include namespace infinilm::layers::mlp { MLP::MLP(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { - const auto &dtype{model_config->get_dtype()}; hidden_size_ = model_config->get("hidden_size"); intermediate_size_ = model_config->get("intermediate_size"); use_bias_ = model_config->get_or("mlp_bias", false); @@ -20,13 +23,25 @@ MLP::MLP(std::shared_ptr model_config, auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }; gate_up_proj_ = std::make_shared( hidden_size_, intermediate_size_, "gate_proj", "up_proj", register_fn, - quantization_method, use_bias_, dtype, device, rank_info); + quantization_method, use_bias_, dtype_, device_, rank_info); down_proj_ = this->register_module( "down_proj", intermediate_size_, hidden_size_, quantization_method, - use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + use_bias_, dtype_, device_, tp_rank, tp_size, rank_info.comm); + + rank_gate_up_output_size_ = gate_up_proj_->out_features() / static_cast(tp_size); + rank_intermediate_size_ = rank_gate_up_output_size_ / 2; + + enable_preallocated_workspace_ = infinilm::global_state::get_infinilm_config().enable_preallocated_workspace; + if (enable_preallocated_workspace_) { + this->_initialize_preallocated_workspace(); + } } infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { + if (enable_preallocated_workspace_) { + return this->_forward_with_preallocated_workspace(hidden_states); + } + // 1. Project to gate and up auto hidden_states_mutable = hidden_states; auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); @@ -36,4 +51,60 @@ infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { auto output = down_proj_->forward(intermediate); return output; } + +infinicore::Tensor MLP::_forward_with_preallocated_workspace(const infinicore::Tensor &hidden_states) const { + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + + // 1. Project to gate and up + auto hidden_states_mutable = hidden_states; + auto gate_up_output = max_gate_up_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_gate_up_output_size_}); + auto [gate, up] = gate_up_proj_->forward_split_(gate_up_output, hidden_states_mutable); + + // 2. Apply SwiGLU: silu(gate) * up + auto intermediate = max_intermediate_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_intermediate_size_}); + infinicore::op::swiglu_(intermediate, up, gate); + + // 3. Project down + auto down_output = max_down_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + down_proj_->forward_(down_output, intermediate); + return down_output; +} + +void MLP::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string mlp_cache_key = std::string("MLP_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_gate_up_output_size_" + + std::to_string(rank_gate_up_output_size_) + "_rank_intermediate_size_" + + std::to_string(rank_intermediate_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + auto align_up = [](size_t n, size_t alignment = 256) { + return (n + alignment - 1) & ~(alignment - 1); + }; + + size_t rank_gate_up_output_size_aligned = align_up(rank_gate_up_output_size_); + size_t rank_intermediate_size_aligned = align_up(rank_gate_up_output_size_aligned + rank_intermediate_size_); + size_t max_output_size = rank_intermediate_size_aligned + hidden_size_; + + if (preallocated_workspace.find(mlp_cache_key) == preallocated_workspace.end()) { + auto mlp_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[mlp_cache_key] = mlp_buffer; + } + + auto mlp_buffer = preallocated_workspace.at(mlp_cache_key); + const auto buffer_shape = mlp_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_gate_up_output_ = mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_gate_up_output_size_}})->view({max_num_batched_tokens, rank_gate_up_output_size_}); + max_intermediate_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_gate_up_output_size_aligned, max_num_batched_tokens * rank_intermediate_size_}})->view({max_num_batched_tokens, rank_intermediate_size_}); + max_down_output_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_intermediate_size_aligned, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); +} + } // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp index 91349fe9b..0987f3649 100644 --- a/csrc/layers/mlp/mlp.hpp +++ b/csrc/layers/mlp/mlp.hpp @@ -2,7 +2,10 @@ #include "../../config/model_config.hpp" #include "../linear/linear.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" namespace infinilm::layers::mlp { @@ -51,6 +54,23 @@ class MLP : public infinicore::nn::Module { size_t hidden_size_; size_t intermediate_size_; bool use_bias_; + infinicore::Device device_; + infinicore::DataType dtype_; + +private: + infinicore::Tensor _forward_with_preallocated_workspace(const infinicore::Tensor &hidden_states) const; + + void _initialize_preallocated_workspace(); + + bool enable_preallocated_workspace_{false}; + + size_t rank_gate_up_output_size_; + size_t rank_intermediate_size_; + + // preallocated workspace for MLP + infinicore::Tensor max_gate_up_output_; + infinicore::Tensor max_intermediate_; + infinicore::Tensor max_down_output_; }; } // namespace infinilm::layers::mlp diff --git a/csrc/layers/quantization/awq.cpp b/csrc/layers/quantization/awq.cpp index 50e830f44..1c07c6dbe 100644 --- a/csrc/layers/quantization/awq.cpp +++ b/csrc/layers/quantization/awq.cpp @@ -52,6 +52,26 @@ infinicore::Tensor AWQ::forward( return infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); } +void AWQ::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto qweight = params.at("qweight"); + auto scales = params.at("scales"); + auto qzeros = params.at("qzeros"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w4a16_awq_(output, input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); +} + std::vector AWQ::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/awq.hpp b/csrc/layers/quantization/awq.hpp index 383e574aa..797092cb4 100644 --- a/csrc/layers/quantization/awq.hpp +++ b/csrc/layers/quantization/awq.hpp @@ -38,6 +38,13 @@ class AWQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/base_quantization.hpp b/csrc/layers/quantization/base_quantization.hpp index 1fd261bdf..0ed82cb35 100644 --- a/csrc/layers/quantization/base_quantization.hpp +++ b/csrc/layers/quantization/base_quantization.hpp @@ -59,6 +59,14 @@ class BaseQuantization : public std::enable_shared_from_this { bool has_bias, float alpha = 1.0f) const = 0; + // In-place forward pass. + virtual void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const = 0; + // Dimension for fused-split (gate/up, q/k/v) of a column-parallel weight. // For NoneQuantization weight [out, in], split is on dim0. // For AWQ qweight [in, out/pack], split is on dim1. diff --git a/csrc/layers/quantization/compressed_tensors.cpp b/csrc/layers/quantization/compressed_tensors.cpp index 66a4a3ef6..45c46fb3e 100644 --- a/csrc/layers/quantization/compressed_tensors.cpp +++ b/csrc/layers/quantization/compressed_tensors.cpp @@ -43,6 +43,25 @@ infinicore::Tensor CompressedTensors::forward( return infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight, weight_scale, bias_opt); } +void CompressedTensors::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto weight = params.at("weight"); + auto weight_scale = params.at("weight_scale"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w8a8i8_(output, input_contiguous->contiguous(), weight, weight_scale, bias_opt); +} + std::vector CompressedTensors::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/compressed_tensors.hpp b/csrc/layers/quantization/compressed_tensors.hpp index dcf65c2e0..2bac396aa 100644 --- a/csrc/layers/quantization/compressed_tensors.hpp +++ b/csrc/layers/quantization/compressed_tensors.hpp @@ -25,6 +25,13 @@ class CompressedTensors : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq.cpp b/csrc/layers/quantization/gptq.cpp index e7688be50..972aa130c 100644 --- a/csrc/layers/quantization/gptq.cpp +++ b/csrc/layers/quantization/gptq.cpp @@ -36,6 +36,16 @@ infinicore::Tensor GPTQ::forward( "Call process_weights_after_loading() first."); } +void GPTQ::forward_( + infinicore::Tensor & /*output*/, + const ParamsMap & /*params*/, + const infinicore::Tensor & /*input*/, + bool /*has_bias*/, + float /*alpha*/) const { + throw std::runtime_error("GPTQ_W4A16 must be converted to GPTQ_QY before forward pass. " + "Call process_weights_after_loading() first."); +} + std::shared_ptr GPTQ::process_weights_after_loading( ParamsMap ¶ms, const infinicore::Device &device) const { diff --git a/csrc/layers/quantization/gptq.hpp b/csrc/layers/quantization/gptq.hpp index 455dde2cc..598be78fb 100644 --- a/csrc/layers/quantization/gptq.hpp +++ b/csrc/layers/quantization/gptq.hpp @@ -34,6 +34,13 @@ class GPTQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.cpp b/csrc/layers/quantization/gptq_qy.cpp index 4098e452d..42ac9b26e 100644 --- a/csrc/layers/quantization/gptq_qy.cpp +++ b/csrc/layers/quantization/gptq_qy.cpp @@ -48,6 +48,25 @@ infinicore::Tensor GPTQ_QY::forward( return output; } +void GPTQ_QY::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto qweight = params.at("qweight"); + auto qzeros = params.at("qzeros"); + auto scales = params.at("scales"); + + infinicore::op::linear_w4a16_gptq_qy_(output, input_contiguous->contiguous(), qweight, scales, qzeros, 0, 4); + + if (has_bias) { + auto bias = params.at("bias"); + infinicore::op::add_(output, output, bias->as_strided(output->shape(), {0, 0, 1})); + } +} + std::vector GPTQ_QY::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.hpp b/csrc/layers/quantization/gptq_qy.hpp index 634b4aaf7..22635fb4d 100644 --- a/csrc/layers/quantization/gptq_qy.hpp +++ b/csrc/layers/quantization/gptq_qy.hpp @@ -112,6 +112,13 @@ class GPTQ_QY : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + // Split fused linear parameters into named sub-parameters std::vector split_params( const std::unordered_map ¶ms, diff --git a/csrc/layers/quantization/none_quantization.cpp b/csrc/layers/quantization/none_quantization.cpp index 6f49a3943..2abf8dcf5 100644 --- a/csrc/layers/quantization/none_quantization.cpp +++ b/csrc/layers/quantization/none_quantization.cpp @@ -14,8 +14,7 @@ std::vector NoneQuantization::get_param_layout( std::vector descs; descs.push_back({"weight", {out_features, in_features}, dtype, split_dim, tp_rank, tp_size}); if (bias) { - descs.push_back({"bias", {out_features}, dtype, split_dim >= 0 ? 0 : -1, - split_dim >= 0 ? tp_rank : 0, split_dim >= 0 ? tp_size : 1}); + descs.push_back({"bias", {out_features}, dtype, split_dim >= 0 ? 0 : -1, split_dim >= 0 ? tp_rank : 0, split_dim >= 0 ? tp_size : 1}); } return descs; } @@ -37,6 +36,24 @@ infinicore::Tensor NoneQuantization::forward( return infinicore::op::linear(input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); } +void NoneQuantization::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha) const { + + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto weight = params.at("weight"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_(output, input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); +} + std::vector NoneQuantization::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/none_quantization.hpp b/csrc/layers/quantization/none_quantization.hpp index 44fd890d9..80fad1f4e 100644 --- a/csrc/layers/quantization/none_quantization.hpp +++ b/csrc/layers/quantization/none_quantization.hpp @@ -6,7 +6,7 @@ namespace infinilm::quantization { class NoneQuantization : public BaseQuantization { public: explicit NoneQuantization(const nlohmann::json &quant_config) - : BaseQuantization(quant_config) {}; + : BaseQuantization(quant_config){}; QuantScheme get_quant_scheme() const override { return QuantScheme::NONE; @@ -25,6 +25,13 @@ class NoneQuantization : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/models/glm4/glm4_attention.cpp b/csrc/models/glm4/glm4_attention.cpp index c5851cc26..579b0e89f 100644 --- a/csrc/models/glm4/glm4_attention.cpp +++ b/csrc/models/glm4/glm4_attention.cpp @@ -57,7 +57,7 @@ Glm4Attention::Glm4Attention(std::shared_ptr mode attn_ = std::make_shared( num_attention_heads_, head_dim_, scaling_, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); // KV Cache quantization scale initialization infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp index 60fe4ee47..0413f274b 100644 --- a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp +++ b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp @@ -49,7 +49,7 @@ AttentionBase::AttentionBase(std::shared_ptr mode float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); infinilm::layers::attention::init_kv_cache_quant_params([this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp index bff9c8d73..676d2567d 100644 --- a/csrc/models/qwen3/qwen3_attention.cpp +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -2,17 +2,19 @@ #include "../../global_state/global_state.hpp" #include "../../layers/attention/attention.hpp" #include "../../utils.hpp" +#include namespace infinilm::models::qwen3 { Qwen3Attention::Qwen3Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ -35,21 +37,24 @@ Qwen3Attention::Qwen3Attention(std::shared_ptr mo qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + use_output_bias, dtype_, device_, tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); - INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype_, device_); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype_, device_); - infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + this->_initialize_preallocated_workspace(); } infinicore::Tensor Qwen3Attention::forward(const infinicore::Tensor &positions, @@ -70,7 +75,8 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos size_t seq_len = shape[1]; // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); @@ -100,9 +106,10 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos // 6. Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto output = o_proj_->forward(attn_output); - return output; + // 6. Project output + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; } infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -118,7 +125,8 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi ASSERT_EQ(batch_size, 1); // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -147,6 +155,35 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output - return o_proj_->forward(attn_output); + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; +} + +void Qwen3Attention::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string attention_cache_key = std::string("Qwen3Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { + auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[attention_cache_key] = attention_buffer; + } + + auto attention_buffer = preallocated_workspace.at(attention_cache_key); + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); } + } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_attention.hpp b/csrc/models/qwen3/qwen3_attention.hpp index 44b69f386..0133a9339 100644 --- a/csrc/models/qwen3/qwen3_attention.hpp +++ b/csrc/models/qwen3/qwen3_attention.hpp @@ -25,6 +25,8 @@ class Qwen3Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _initialize_preallocated_workspace(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -39,9 +41,17 @@ class Qwen3Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + + size_t rank_qkv_output_size_; + + // preallocated workspace for Attention + infinicore::Tensor max_qkv_output_; + infinicore::Tensor max_o_output_; }; } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3_next/qwen3_next_attention.cpp b/csrc/models/qwen3_next/qwen3_next_attention.cpp index 67fd38082..5cf469d5a 100644 --- a/csrc/models/qwen3_next/qwen3_next_attention.cpp +++ b/csrc/models/qwen3_next/qwen3_next_attention.cpp @@ -49,7 +49,7 @@ Qwen3NextAttention::Qwen3NextAttention(std::shared_ptr(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 0d480bbff..6284f6c5f 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -1,5 +1,6 @@ #include "../../engine/infer_engine.hpp" #include "infinicore/tensor.hpp" +#include #include #include @@ -38,7 +39,8 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, - std::optional kv_cache_dtype) { + std::optional kv_cache_dtype, + size_t max_num_batched_tokens) { return std::make_shared( model_path, dist, @@ -46,7 +48,8 @@ inline void bind_infer_engine(py::module &m) { cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, infinilm::backends::parse_attention_backend(attention_backend), - kv_cache_dtype); + kv_cache_dtype, + max_num_batched_tokens); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), @@ -54,7 +57,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = py::none()) + py::arg("kv_cache_dtype") = py::none(), + py::arg("max_num_batched_tokens") = 2048) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -70,11 +74,20 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)") - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("forward", + [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + // IMPORTANT: Release the GIL before calling forward() to allow other Python threads + // to run concurrently during inference (which may block for a long time). + // Do NOT remove this — without it, the GIL is held throughout inference and will + // deadlock or stall any other Python thread (e.g., request handling, scheduling). + py::gil_scoped_release release; + return self.forward(input); + }, + "Run inference on all ranks with arbitrary arguments") .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { - auto cfg = self.get_cache_config(); - return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) + auto cfg = self.get_cache_config(); + return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) .def("__repr__", [](const InferEngine &self) { return ""; }); py::class_(infer_engine, "Input") diff --git a/examples/bench.py b/examples/bench.py index 6672d6d7d..29d15211c 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -169,6 +169,7 @@ def __init__( cache_config=None, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int = None, ) -> None: model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -182,6 +183,7 @@ def __init__( enable_graph_compiling=enable_graph, attention_backend=attn_backend, kv_cache_dtype=cfg.kv_cache_dtype, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # @@ -281,6 +283,7 @@ def run( enable_paged_attn = cfg.enable_paged_attn enable_graph = cfg.enable_graph attn_backend = cfg.attn + max_num_batched_tokens = cfg.max_num_batched_tokens if isinstance(batch_size, int): batch_size = [batch_size] @@ -322,6 +325,7 @@ def run( cache_config=cache_config, enable_graph=enable_graph, attn_backend=attn_backend, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd459..c0fd306b7 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -44,7 +44,6 @@ class BaseConfig: """InfiniLM Unified Config - Command line argument parser""" def __init__(self): - self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config") self._add_common_args() self.args, self.extra = self.parser.parse_known_args() @@ -70,6 +69,7 @@ def __init__(self): self.batch_size = self.args.batch_size self.max_batch_size = self.args.max_batch_size + self.max_num_batched_tokens = self.args.max_num_batched_tokens self.input_len = self.args.input_len self.output_len = self.args.output_len self.max_new_tokens = self.args.max_new_tokens @@ -155,6 +155,12 @@ def _add_common_args(self): default=8, help="maximum batch size for server", ) + self.parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="maximum number of batched tokens for paged attention", + ) self.parser.add_argument( "--input-len", type=parse_list, default=10, help="input sequence length" ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 31d65de93..652ab9d89 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -24,6 +24,7 @@ def read_hf_config(model_path): ) return config_dict + # config.json (required) defines model architecture, while generation_config.json # (optional) defines generation behavior. They are kept as separate readers # because: 1) config.json must exist and requires model_type validation, @@ -37,6 +38,7 @@ def read_hf_generation_config(model_path): return json.load(f) return {} + @dataclass class GenerationConfig: max_new_tokens: int | None = None @@ -59,6 +61,7 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + max_num_batched_tokens: int | None = None, ): self.hf_config = read_hf_config(model_path) self.hf_generation_config = read_hf_generation_config(model_path) @@ -66,6 +69,12 @@ def __init__( if device is None: device = infinicore.device() + max_position_embeddings = self.hf_config["max_position_embeddings"] + if max_num_batched_tokens is None: + max_num_batched_tokens = max_position_embeddings + assert 512 <= max_num_batched_tokens <= max_position_embeddings + self.max_num_batched_tokens = max_num_batched_tokens + hf_config_str = json.dumps(self.hf_config) super().__init__( hf_config_str, @@ -79,6 +88,7 @@ def __init__( if kv_cache_dtype is not None else None ), + max_num_batched_tokens, ) self.use_cache = False @@ -364,6 +374,6 @@ def state_dict_keyname(self): def load_state_dict(self, state_dict, strict=None): for name, param in state_dict.items(): super().load_param(name, param._underlying) - + def process_weights_after_loading(self): super().process_weights_after_loading() diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83a..69f640c1a 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -6,6 +6,7 @@ - AsyncLLM class for asynchronous streaming (server use) """ +import os import asyncio import time import uuid @@ -74,6 +75,7 @@ class EngineConfig: enable_graph: bool = False attn_backend: str = "default" skip_load: bool = False + max_num_batched_tokens: int | None = None class LLMEngine: @@ -92,7 +94,9 @@ def __init__(self, config: EngineConfig): distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, + max_num_batched_tokens=config.max_num_batched_tokens, ) + self.max_num_batched_tokens = self.model_engine.max_num_batched_tokens # Load model weights if not self.config.skip_load: @@ -117,10 +121,12 @@ def __init__(self, config: EngineConfig): cache_config = PagedKVCacheConfig( num_blocks=config.num_blocks, block_size=config.block_size ) + self.scheduler = Scheduler( max_batch_size=config.max_batch_size, num_blocks=config.num_blocks, block_size=config.block_size, + max_num_batched_tokens=self.max_num_batched_tokens, ) logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") else: diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index f9c11635a..daeccbc38 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -39,6 +39,7 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_num_batched_tokens: int = 1024, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() @@ -47,6 +48,8 @@ def __init__( self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + self.max_num_batched_tokens = max_num_batched_tokens + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING @@ -56,9 +59,13 @@ def schedule(self) -> Optional[SchedulerOutput]: """Schedule and return batch of requests to execute.""" scheduled_requests = [] is_prefill = False + current_num_batched_tokens = 0 # Process Waiting queue (prefill phase) - while len(scheduled_requests) < self.max_batch_size: + while ( + len(scheduled_requests) < self.max_batch_size + and current_num_batched_tokens < self.max_num_batched_tokens + ): try: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: @@ -89,6 +96,23 @@ def schedule(self) -> Optional[SchedulerOutput]: self.cache_manager.allocate_blocks(req_tokens, req.block_table) ) + num_tokens_this_step = ( + req.get_prompt_length() - req.num_cached_tokens + ) + if ( + current_num_batched_tokens + num_tokens_this_step + >= self.max_num_batched_tokens + ): + if req.num_cached_tokens > 0: + self.cache_manager.free_blocks(req.block_table) + req.block_table = [] + req.slot_mapping = [] + req.num_cached_tokens = 0 + + self.waiting_queue.sync_q.put(req) + break + + current_num_batched_tokens += num_tokens_this_step req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING scheduled_requests.append(req) diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index c15c950fe..476e45de5 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -55,6 +55,7 @@ def __init__( enable_paged_attn=False, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int | None = None, ): import transformers import infinicore @@ -119,6 +120,7 @@ def __init__( ), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + max_num_batched_tokens=max_num_batched_tokens, ) # Enable KV cache for generation @@ -1125,6 +1127,7 @@ def main(): cfg.bench, cfg.enable_paged_attn, cfg.enable_graph, + cfg.max_num_batched_tokens, cfg.attn, )