2121import itertools
2222from dataclasses import dataclass
2323# import xxhash
24- # from line_profiler import profile
24+ from line_profiler import profile
2525
2626# TODO:
2727# - ExLlamaV2StreamingGenerator wrapper
@@ -893,6 +893,11 @@ def iterate(self) -> list[dict]:
893893 "stop_string"
894894 "max_new_tokens"
895895 "end_filter"
896+ optional, if "eos_reason" == "stop_token":
897+ "eos_triggering_token_id": int
898+ "eos_triggering_token_str": str
899+ optional, if "eos_reason" == "stop_string":
900+ "eos_triggering_string": str
896901 "full_completion": str - full text completion
897902 "new_tokens": int - number of tokens generated
898903 "time_enqueued": float - time from job was enqueued until it started, in seconds
@@ -1849,7 +1854,9 @@ def emit(
18491854 eos_reason : str = None ,
18501855 emit_held = False ,
18511856 suppressed_text = None ,
1852- suppressed_tokens = None
1857+ suppressed_tokens = None ,
1858+ stop_token : int = None ,
1859+ stop_string : str = None
18531860 ):
18541861 r = {
18551862 "job" : self ,
@@ -1860,6 +1867,15 @@ def emit(
18601867
18611868 if eos_reason is not None :
18621869 r .update ({ "eos_reason" : eos_reason })
1870+ if eos_reason == "stop_token" :
1871+ id_to_piece = self .generator .tokenizer .get_id_to_piece_list (True )
1872+ r .update ({
1873+ "eos_triggering_token_id" : stop_token ,
1874+ "eos_triggering_token_str" : id_to_piece [stop_token ]
1875+ })
1876+ pass
1877+ if eos_reason == "stop_string" :
1878+ r .update ({ "eos_triggering_string" : stop_string })
18631879
18641880 if emit_held :
18651881 if self .held_text != "" :
@@ -1913,7 +1929,7 @@ def emit(
19131929 # End on stop tokens
19141930
19151931 if next_token .item () in self .stop_tokens :
1916- return emit (results , emit_eos = True , eos_reason = "stop_token" )
1932+ return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token . item () )
19171933
19181934 # Decode and buffer output
19191935
@@ -2032,8 +2048,12 @@ def rewind_checkpoint():
20322048 self .stop_strings_utf32_buffer
20332049 )
20342050 if match >= 0 :
2051+ held = self .held_text [match :]
20352052 self .held_text = self .held_text [:match ]
2036- return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" )
2053+ for s in self .stop_strings :
2054+ if held .startswith (s ):
2055+ return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" , stop_string = s )
2056+ assert False , "Detected stop string but couldn't identify it (logic error)"
20372057 if match == - 2 :
20382058 return emit (results )
20392059
0 commit comments