Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions csrc/layers/mlp/mlp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class MLP : public infinicore::nn::Module {
size_t hidden_size() const { return hidden_size_; }
size_t intermediate_size() const { return intermediate_size_; }

infinicore::Tensor gate_up_weight() const { return gate_up_proj_->weight(); }
infinicore::Tensor down_weight() const { return down_proj_->weight(); }
float down_alpha() const { return down_proj_->alpha(); }

protected:
std::shared_ptr<infinilm::layers::linear::GateUpParallelLinear> gate_up_proj_;
std::shared_ptr<infinilm::layers::linear::RowParallelLinear> down_proj_;
Expand Down
3 changes: 2 additions & 1 deletion csrc/models/fm9g/fm9g_for_causal_lm.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "../../layers/common_modules.hpp"
#include "fm9g_fused_decoder_layer.hpp"
#include "infinicore/nn/linear.hpp"
#include <cmath>
#include <memory>
Expand Down Expand Up @@ -38,7 +39,7 @@ class FM9GMLP : public infinilm::layers::mlp::MLP {
}
};

using FM9GDecoderLayer = infinilm::layers::causal_lm_templates::TextDecoderLayer<FM9GAttention, FM9GMLP>;
using FM9GDecoderLayer = FM9GFusedDecoderLayer<FM9GAttention, FM9GMLP>;

using FM9GModel = infinilm::layers::causal_lm_templates::TextModel<FM9GDecoderLayer>;

Expand Down
108 changes: 108 additions & 0 deletions csrc/models/fm9g/fm9g_fused_decoder_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#pragma once

#include "../../layers/causal_lm_templates/text_decoder_layer.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/fused_ffn.hpp"

#include <cstdlib>
#include <string>
#include <tuple>

namespace infinilm::models::fm9g {

// FM9G decoder layer that may substitute the post-attention rms_norm + MLP
// block with the InfiniCore fused-FFN op.
//
// The substitution is gated per `forward()` call by `INFINILM_USE_FUSED_FFN`
// so benchmarks can interleave fused and non-fused passes within one process.
// When MuP scaling on `down_proj` is active (alpha != 1.0), the per-op path
// is taken to preserve the multiplier the fused kernel does not model.
//
// The fused kernel accepts only rank-2 `[ntok, hidden]` tensors, so the call
// site views `hidden_states` to 2-D and back.
template <typename Attention, typename MLP>
class FM9GFusedDecoderLayer
: public infinilm::layers::causal_lm_templates::TextDecoderLayer<Attention, MLP> {
using Base = infinilm::layers::causal_lm_templates::TextDecoderLayer<Attention, MLP>;

public:
FM9GFusedDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
size_t layer_idx,
const infinicore::Device &device)
: Base(model_config, layer_idx, device),
rms_norm_eps_(static_cast<float>(model_config->get<double>("rms_norm_eps"))) {}

std::tuple<infinicore::Tensor, infinicore::Tensor>
forward(const infinicore::Tensor &positions,
infinicore::Tensor &hidden_states,
infinicore::Tensor &residual) {
this->input_layernorm_->forward_inplace(hidden_states, residual);
hidden_states = this->self_attn_->forward(positions, hidden_states);

if (use_fused_ffn()) {
residual = infinicore::op::add(residual, hidden_states);
auto fused_in_2d = as_2d(residual);
auto fused_out_2d = infinicore::op::fused_ffn(
fused_in_2d, std::nullopt,
this->post_attention_layernorm_->weight(),
this->mlp_->gate_up_weight(),
this->mlp_->down_weight(),
rms_norm_eps_);
hidden_states = fused_out_2d->view(residual->shape());
} else {
this->post_attention_layernorm_->forward_inplace(hidden_states, residual);
hidden_states = this->mlp_->forward(hidden_states);
}
return std::make_tuple(hidden_states, residual);
}

infinicore::Tensor forward(const infinicore::Tensor &positions,
infinicore::Tensor &hidden_states) {
auto residual = hidden_states;
hidden_states = this->input_layernorm_->forward(hidden_states);
hidden_states = this->self_attn_->forward(positions, hidden_states);
hidden_states = infinicore::op::add(residual, hidden_states);

if (use_fused_ffn()) {
const auto orig_shape = hidden_states->shape();
auto fused_in_2d = as_2d(hidden_states);
auto fused_residual_2d = as_2d(hidden_states);
auto fused_out_2d = infinicore::op::fused_ffn(
fused_in_2d, fused_residual_2d,
this->post_attention_layernorm_->weight(),
this->mlp_->gate_up_weight(),
this->mlp_->down_weight(),
rms_norm_eps_);
hidden_states = fused_out_2d->view(orig_shape);
} else {
residual = hidden_states;
hidden_states = this->post_attention_layernorm_->forward(hidden_states);
hidden_states = this->mlp_->forward(hidden_states);
hidden_states = infinicore::op::add(residual, hidden_states);
}
return hidden_states;
}

private:
bool use_fused_ffn() const {
if (this->mlp_->down_alpha() != 1.0f) {
return false;
}
const char *env = std::getenv("INFINILM_USE_FUSED_FFN");
return env != nullptr && std::string(env) == "1";
}

static infinicore::Tensor as_2d(const infinicore::Tensor &t) {
const auto &shape = t->shape();
size_t hidden = shape.back();
size_t ntok = 1;
for (size_t i = 0; i + 1 < shape.size(); ++i) {
ntok *= shape[i];
}
return t->view({ntok, hidden});
}

float rms_norm_eps_;
};

} // namespace infinilm::models::fm9g
Loading