return splitk_sparse_gemv(hidden_states, weights, threshold, sparsity_bin) if hidden_states.shape[1] == 1 else torch.matmul(hidden_states, weights.T)
Hi, I noticed that the SparseGEMV kernel only handles the case where batch_size=1 and seqlen=1. For any other input shape, the kernel produces incorrect output.
Is it expected that this kernel only works for the decoding stage? If so, where is the implementation of Appendix A4?
TEAL/kernels/sparse_gemv.py
Line 271 in fb7373c
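For anyone who wants to check this, here is a minimal sketch of a dense reference to compare the kernel against. It is not TEAL code: `reference_sparse_matmul` and `is_single_token_decode` are hypothetical helpers, and it assumes `hidden_states` has shape (batch, seqlen, hidden) and that sparsification means zeroing activations whose magnitude is at or below `threshold`. Under that layout, the condition in the line above only checks `shape[1]`, so an input of shape (4, 1, hidden) would presumably still be sent to the kernel.

```python
import torch


def reference_sparse_matmul(hidden_states, weights, threshold):
    # Assumed semantics: keep only activations with |x| > threshold,
    # then do an ordinary dense matmul. Works for any batch/seqlen.
    mask = hidden_states.abs() > threshold
    return torch.matmul(hidden_states * mask, weights.T)


def is_single_token_decode(hidden_states):
    # Hypothetical stricter guard: require batch_size == 1 AND seqlen == 1,
    # instead of checking seqlen alone.
    return (
        hidden_states.dim() == 3
        and hidden_states.shape[0] == 1
        and hidden_states.shape[1] == 1
    )


torch.manual_seed(0)
weights = torch.randn(8, 16)      # (out_features, hidden)
decode = torch.randn(1, 1, 16)    # batch_size=1, seqlen=1
batched = torch.randn(4, 1, 16)   # batch_size=4, seqlen=1

# Expected output for the batched case, to compare against the kernel.
expected = reference_sparse_matmul(batched, weights, threshold=0.5)
```

The kernel output for `batched` could then be compared with `expected` using `torch.allclose` and a tolerance suited to the dtype.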