@@ -1856,7 +1856,8 @@ def emit(
18561856 suppressed_text = None ,
18571857 suppressed_tokens = None ,
18581858 stop_token : int = None ,
1859- stop_string : str = None
1859+ stop_string : str = None ,
1860+ rem_held_text : str = None
18601861 ):
18611862 r = {
18621863 "job" : self ,
@@ -1919,18 +1920,29 @@ def emit(
19191920 "accepted_draft_tokens" : self .accepted_draft_tokens ,
19201921 "rejected_draft_tokens" : self .rejected_draft_tokens
19211922 })
1923+ if eos_reason == "stop_string" :
1924+ self .held_text = rem_held_text
1925+ rh = {}
1926+ if self .held_text :
1927+ rh .update ({ "text" : self .held_text })
1928+ if self .held_tokens :
1929+ rh .update ({ "token_ids" : self .held_tokens .torch ().clone () })
1930+ if self .held_probs :
1931+ rh .update ({ "token_probs" : self .held_probs .torch ().clone () })
1932+ if self .held_k_tokens :
1933+ rh .update ({ "top_k_tokens" : self .held_k_tokens .torch ().clone () })
1934+ rh .update ({ "top_k_probs" : self .held_k_probs .torch ().clone () })
1935+ if self .held_logits :
1936+ rh .update ({ "logits" : self .held_logits .torch ().clone () })
1937+ if rh :
1938+ r .update ({ "held" : rh })
19221939
19231940 if self .identifier is not None :
19241941 r .update ({ "identifier" : self .identifier })
19251942
19261943 results .append (r )
19271944 return emit_eos , next_token
19281945
1929- # End on stop tokens
1930-
1931- if next_token .item () in self .stop_tokens :
1932- return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token .item ())
1933-
19341946 # Decode and buffer output
19351947
19361948 id_to_piece = self .generator .tokenizer .get_id_to_piece_list (self .decode_special_tokens )
@@ -1950,6 +1962,11 @@ def emit(
19501962 if self .return_logits :
19511963 self .held_logits .append (logits [:1 , :, :])
19521964
1965+ # End on stop tokens
1966+
1967+ if next_token .item () in self .stop_tokens :
1968+ return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token .item ())
1969+
19531970 # Stop if we reach max_new_tokens
19541971
19551972 if self .new_tokens >= self .max_new_tokens - self .generator .num_draft_tokens :
@@ -2052,7 +2069,14 @@ def rewind_checkpoint():
20522069 self .held_text = self .held_text [:match ]
20532070 for s in self .stop_strings :
20542071 if held .startswith (s ):
2055- return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" , stop_string = s )
2072+ return emit (
2073+ results ,
2074+ emit_eos = True ,
2075+ emit_held = True ,
2076+ eos_reason = "stop_string" ,
2077+ stop_string = s ,
2078+ rem_held_text = held
2079+ )
20562080 assert False , "Detected stop string but couldn't identify it (logic error)"
20572081 if match == - 2 :
20582082 return emit (results )
0 commit comments