@@ -3,7 +3,7 @@
 from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
 from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM
 # import transformers
-
+import pdb
 # # monkey patch
 # transformers.models.llama.modeling_llama.LlamaForCausalLM = KVLlamaForCausalLM
 # transformers.models.mistral.modeling_mistral.MistralForCausalLM = KVMistralForCausalLM
@@ -121,6 +121,7 @@ def __init__(
     @property
     def base_model(self):
         return self
+
     @classmethod
     def from_pretrained(
         cls,
@@ -219,6 +220,7 @@ def forward(
         if output_orig:
             return torch.stack(medusa_logits, dim=0), outputs, orig
         return torch.stack(medusa_logits, dim=0)
+
     def get_medusa_choice(self, model_name):
         if 'vicuna' in model_name:
             if '7b' in model_name:
@@ -264,10 +266,11 @@ def medusa_generate(

         Warning: Only support batch size 1 for now!!
         """
-        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+        batch_size = input_ids.shape[0]
+        valid_length = attention_mask.sum(dim=1)
         # Avoid modifying the input_ids in-place
         input_ids = input_ids.clone()
-
         # Cache medusa buffers (the fixed patterns for tree attention)
         if medusa_choices is None:
             medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
@@ -295,7 +298,7 @@ def medusa_generate(
                 past_key_values,
                 past_key_values_data,
                 current_length_data,
-            ) = initialize_past_key_values(self.base_model)
+            ) = initialize_past_key_values(self.base_model, batch_size)
             self.past_key_values = past_key_values
             self.past_key_values_data = past_key_values_data
             self.current_length_data = current_length_data
@@ -305,12 +308,11 @@ def medusa_generate(
         reset_medusa_mode(self)
         # Initialize tree attention mask and process prefill tokens
         medusa_logits, logits = initialize_medusa(
-            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
+            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values, attention_mask
         )
-
         new_token = 0
         last_round_token = 0
-
+        ends = [input_len] * batch_size
         for idx in range(max_steps):
             # Generate candidates with topk predictions from Medusa heads
             candidates, tree_candidates = generate_candidates(
@@ -324,8 +326,8 @@ def medusa_generate(
                 top_p=top_p,
                 sampling=sampling,
                 fast=fast,
+                valid_length=valid_length
             )
-
             # Use tree attention to verify the candidates and get predictions
             medusa_logits, logits, outputs = tree_decoding(
                 self,
@@ -334,15 +336,14 @@ def medusa_generate(
                 medusa_buffers["medusa_position_ids"],
                 input_ids,
                 medusa_buffers["retrieve_indices"],
+                attention_mask=attention_mask
             )
-
             # Evaluate the posterior of the candidates to select the accepted candidate prefix
             best_candidate, accept_length = evaluate_posterior(
                 logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast
             )
-
             # Update the input_ids and logits
-            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
+            input_ids, logits, medusa_logits, new_token, valid_length, attention_mask = update_inference_inputs(
                 input_ids,
                 candidates,
                 best_candidate,
@@ -354,18 +355,29 @@ def medusa_generate(
                 new_token,
                 past_key_values_data,
                 current_length_data,
+                attention_mask=attention_mask,
+                padding_idx=self.tokenizer.pad_token_id
             )

-            yield {
-                "text": self.tokenizer.decode(
-                    input_ids[0, input_len:],
+            decoded_texts = []
+            eos_encountered = [False] * batch_size
+            for i in range(batch_size):
+                # Check whether this batch element already contains an EOS token
+                if self.tokenizer.eos_token_id in input_ids[i, input_len:]:
+                    eos_encountered[i] = True
+                else:
+                    ends[i] = len(input_ids[i])
+                decoded_text = self.tokenizer.decode(
+                    input_ids[i, input_len:ends[i]],
                     skip_special_tokens=True,
                     spaces_between_special_tokens=False,
                     clean_up_tokenization_spaces=True,
                 )
-            }
+                decoded_texts.append(decoded_text)
+            yield {"text": decoded_texts}

-            if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
+            # Stop once every batch element has encountered EOS
+            if all(eos_encountered):
                 break

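For reference, here is a minimal sketch of how a caller might drive the batched generator after this change. It is an assumption-laden example, not part of the commit: it presumes `medusa_generate` accepts an `attention_mask` keyword (the diff threads one through the loop, but the signature change is not shown), that the tokenizer is exposed via `get_tokenizer()` as in the upstream repo, and that left padding is used so generated tokens stay contiguous on the right. The checkpoint name and prompts are placeholders.

```python
import torch
from medusa.model.medusa_model import MedusaModel

# Placeholder checkpoint; any Medusa checkpoint should work the same way.
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = model.get_tokenizer()
tokenizer.padding_side = "left"                # left padding: real tokens end-aligned
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as padding

prompts = [
    "What is speculative decoding?",
    "Explain tree attention in one sentence.",
]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

# With this patch, the yielded dict's "text" field is a list,
# holding one partial completion per batch row.
for step in model.medusa_generate(
    input_ids=batch.input_ids,
    attention_mask=batch.attention_mask,
    temperature=0.0,
    max_steps=512,
):
    texts = step["text"]

print(texts)
```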
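And a toy illustration of the `valid_length` bookkeeping the commit introduces: with left padding, summing the attention mask along the sequence dimension recovers each row's true prompt length, and the batch size is now forwarded to `initialize_past_key_values` so the KV cache can be allocated per row.

```python
import torch

# Two left-padded prompts (0 marks padding): row 0 has 3 real tokens, row 1 has 5.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

valid_length = attention_mask.sum(dim=1)  # tensor([3, 5]): per-row prompt lengths
batch_size = attention_mask.shape[0]      # 2: forwarded to initialize_past_key_values
print(valid_length.tolist(), batch_size)
```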