Skip to content

Commit 71b4138

Browse files
Merge pull request #396 from AI-Hypercomputer:fix-ulysses-custom
PiperOrigin-RevId: 910099466
2 parents 3ef0fdd + f23c50c commit 71b4138

2 files changed

Lines changed: 4 additions & 5 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
from . import quantizations
3636
from .modeling_flax_utils import get_activation
3737

38+
LOG2E = math.log2(math.e)
3839

3940
Array = common_types.Array
4041
Mesh = common_types.Mesh
@@ -591,9 +592,7 @@ def wrap_ulysses_attention(query, key, value):
591592
heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
592593

593594
if use_base2_exp:
594-
query_scaled = query * 1.44269504
595-
else:
596-
query_scaled = query
595+
query = query * LOG2E
597596

598597
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
599598
key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
@@ -612,7 +611,7 @@ def wrap_ulysses_attention(query, key, value):
612611
)
613612

614613
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
615-
attention_output = vmapped_splash(query_scaled, key, value)
614+
attention_output = vmapped_splash(query, key, value)
616615
attention_output = jnp.swapaxes(attention_output, 2, 3)
617616
attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
618617
else:

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,7 @@ def user_init(raw_keys):
214214
# Verify qkv is sharded across sequence.
215215
attention = raw_keys["attention"]
216216
uses_ring_attention = "ring" in attention
217-
uses_ulysses_attention = attention == "ulysses"
217+
uses_ulysses_attention = "ulysses" in attention
218218
uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
219219
if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
220220
max_logging.log(

0 commit comments

Comments (0)