from . import quantizations
from .modeling_flax_utils import get_activation

+ LOG2E = math.log2(math.e)

Array = common_types.Array
Mesh = common_types.Mesh
@@ -591,9 +592,7 @@ def wrap_ulysses_attention(query, key, value):
    heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)

    if use_base2_exp:
-     query_scaled = query * 1.44269504
-   else:
-     query_scaled = query
+     query = query * LOG2E

    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
@@ -612,7 +611,7 @@ def wrap_ulysses_attention(query, key, value):
    )

    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
-   attention_output = vmapped_splash(query_scaled, key, value)
+   attention_output = vmapped_splash(query, key, value)
    attention_output = jnp.swapaxes(attention_output, 2, 3)
    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
  else:
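
For context on the change above: multiplying the query by LOG2E = log2(e) lets an attention kernel evaluate the softmax with base-2 exponentials, since exp(x) == 2**(x * log2(e)) and exp2 is generally cheaper than exp on accelerators; because (q * c) @ k^T == c * (q @ k^T), the conversion factor can be folded into the query once before the kernel runs. The following is a minimal standalone sketch of that identity, not the splash kernel touched by this diff; the helper names are made up for the example.

# Illustrative sketch only: verifies the base-2 softmax identity behind the
# LOG2E pre-scaling. Function names are hypothetical, not from the codebase.
import math

import jax
import jax.numpy as jnp

LOG2E = math.log2(math.e)


def softmax_base_e(logits):
  # Reference softmax using the natural exponential.
  return jax.nn.softmax(logits, axis=-1)


def softmax_base_2(logits):
  # Same result computed with exp2 after scaling the logits by log2(e).
  scaled = logits * LOG2E
  scaled = scaled - jnp.max(scaled, axis=-1, keepdims=True)  # numerical stability
  weights = jnp.exp2(scaled)
  return weights / jnp.sum(weights, axis=-1, keepdims=True)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 16))

# Scaling the query before the matmul is equivalent to scaling the logits.
logits_from_scaled_q = jnp.einsum("bqd,bkd->bqk", q * LOG2E, k)
logits_scaled_after = LOG2E * jnp.einsum("bqd,bkd->bqk", q, k)
assert jnp.allclose(logits_from_scaled_q, logits_scaled_after, atol=1e-4)

# The base-2 softmax over unscaled logits matches the base-e softmax.
logits = jnp.einsum("bqd,bkd->bqk", q, k)
assert jnp.allclose(softmax_base_e(logits), softmax_base_2(logits), atol=1e-6)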