[WIP] feat: support MLA and refactor MHA #163
Open
Chamberlain0w0 wants to merge 4 commits into
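Modify the current MHA implementation
a. The original TransformerConfig::attention_type = kStandard / kRoPE was a poor fit: Megatron and other open-source implementations usually use attn_type to distinguish self from cross attention. The setting is renamed to match Megatron's --position-embedding-type, with the possible values learned_absolute / rope / yarn / mrope / relative / none. The conditions that decide whether to create WPE and whether to apply RoPE are updated accordingly.
b. Removed the two branches CausalSelfAttention::ForwardStandard and ForwardWithRoPE and merged them into a single unified Forward. GQA is folded into the same unified path.
c. ApplyRotaryEmbedding is moved out of CausalSelfAttention (where it was a member function) into transformer utils.cc.
d. The causal mask buffer is now initialized for both learned absolute and RoPE. If the caller does not pass a mask, Forward falls back to the internal causal mask. This is a small behavior unification for the case where the RoPE path is called directly without a mask.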
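The behavior described in items a, b, and d can be illustrated with a minimal, standalone C++17 sketch. It is not the code in this PR: the enum, the helper names (ParsePositionEmbeddingType, NeedsWPE, AppliesRotary, KvHeadFor, MakeCausalMask, SelectMask), and the vector-based mask type are all hypothetical stand-ins chosen for illustration. Only the flag values come from the description above. Whether yarn and mrope share the RoPE code path is not stated in the PR, so the sketch treats only rope as rotary.

```cpp
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical enum mirroring the values listed for --position-embedding-type.
enum class PositionEmbeddingType { kLearnedAbsolute, kRoPE, kYaRN, kMRoPE, kRelative, kNone };

// Parse the flag string; throws std::invalid_argument on an unknown value.
inline PositionEmbeddingType ParsePositionEmbeddingType(const std::string &s) {
    if (s == "learned_absolute") return PositionEmbeddingType::kLearnedAbsolute;
    if (s == "rope") return PositionEmbeddingType::kRoPE;
    if (s == "yarn") return PositionEmbeddingType::kYaRN;
    if (s == "mrope") return PositionEmbeddingType::kMRoPE;
    if (s == "relative") return PositionEmbeddingType::kRelative;
    if (s == "none") return PositionEmbeddingType::kNone;
    throw std::invalid_argument("unknown position embedding type: " + s);
}

// A learned position table (WPE) is only needed for learned absolute positions.
inline bool NeedsWPE(PositionEmbeddingType t) {
    return t == PositionEmbeddingType::kLearnedAbsolute;
}

// Assumption: only plain rope takes the rotary path in this sketch.
inline bool AppliesRotary(PositionEmbeddingType t) {
    return t == PositionEmbeddingType::kRoPE;
}

// GQA in a unified path: each query head reads the KV head of its group.
// With n_kv_head == n_head this is the identity mapping, i.e. plain MHA.
inline std::size_t KvHeadFor(std::size_t q_head, std::size_t n_head, std::size_t n_kv_head) {
    if (n_kv_head == 0 || n_head % n_kv_head != 0) {
        throw std::invalid_argument("n_head must be a positive multiple of n_kv_head");
    }
    return q_head / (n_head / n_kv_head);
}

// Toy mask type: mask[i][j] == true means position i may attend to position j.
using Mask = std::vector<std::vector<bool>>;

// Lower-triangular causal mask, built regardless of the position embedding type.
inline Mask MakeCausalMask(std::size_t seq_len) {
    Mask m(seq_len, std::vector<bool>(seq_len, false));
    for (std::size_t i = 0; i < seq_len; ++i) {
        for (std::size_t j = 0; j <= i; ++j) {
            m[i][j] = true;
        }
    }
    return m;
}

// Use the caller's mask if one was passed, otherwise fall back to the internal one.
inline const Mask &SelectMask(const std::optional<Mask> &external, const Mask &internal_causal) {
    return external.has_value() ? *external : internal_causal;
}
```

In this sketch a single Forward would call NeedsWPE and AppliesRotary instead of dispatching to separate standard and RoPE branches, and would always call SelectMask, so a RoPE model invoked without a mask still gets causal masking.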
Add the MLA module
--TODO--