@@ -16,10 +16,12 @@
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -178,6 +180,14 @@
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
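+# The flash backend is optional: if this import fails (e.g. its attention
+# dependencies are missing), disable it and keep the regular CausalLM path.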
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -381,6 +391,21 @@ def get_model(
         )
     model_type = config_dict.get("model_type", None)
 
+    transformers_causal_lm_class = CausalLM
+
+    # Fast transformers path
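+    # Resolve the stock `transformers` class for this model_type (e.g.
+    # "llama" -> LlamaForCausalLM); getattr returns None when no causal-LM
+    # class is registered for the architecture.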
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
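+    # Use the flash backend only when its import succeeded and the upstream
+    # class declares flex-attention support (`_supports_flex_attn`).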
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        transformers_causal_lm_class = TransformersFlashCausalLM
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
@@ -624,7 +649,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -683,7 +708,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -731,7 +756,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -742,7 +767,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -767,7 +792,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -778,7 +803,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -815,7 +840,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -838,7 +863,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -862,7 +887,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -911,7 +936,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -937,7 +962,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -963,7 +988,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1013,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1016,7 +1041,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1066,7 +1091,7 @@ def get_model(
                 config_class=RWConfig,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1116,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1116,7 +1141,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1143,7 +1168,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1168,7 +1193,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1329,7 +1354,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1350,7 +1375,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,