@@ -831,6 +831,9 @@ def forward_paged_tp_old(
 
     def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
 
+        num_attn_heads = q_states.shape[2]
+        head_dim = q_states.shape[3]
+
         q_states = q_states.transpose(1, 2)
         k_states = k_states.transpose(1, 2)
         v_states = v_states.transpose(1, 2)
@@ -881,7 +884,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
         attn_output = torch.matmul(attn_weights, v_states)
 
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * head_dim))
         return attn_output
 
 
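The TP fallback added further down in this diff calls _attn_torch with per-device head slices, so the output reshape can no longer rely on cfg.num_attention_heads. A minimal standalone sketch of the idea, with purely illustrative shapes:

# Illustrative only: under tensor parallelism each device holds a slice of the
# attention heads, so the reshape must use the runtime tensor shape.
import torch

batch_size, q_len = 2, 16
cfg_num_attention_heads, head_dim, tp_devices = 32, 128, 4

# (batch, q_len, heads on this device, head_dim)
q_states = torch.randn(batch_size, q_len, cfg_num_attention_heads // tp_devices, head_dim)

num_attn_heads = q_states.shape[2]   # 8 on this device, not 32
head_dim_rt = q_states.shape[3]      # 128

# The attention result is reshaped back to (batch, q_len, heads * head_dim);
# using the runtime head count keeps this correct for the per-device slice.
attn_output = q_states.reshape((batch_size, q_len, num_attn_heads * head_dim_rt))
# cfg_num_attention_heads * head_dim (32 * 128) would not match the element count here.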
@@ -955,8 +958,10 @@ def forward(self,
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
 
+        cfg = self.model.config
         global has_flash_attn
         global has_xformers
+        use_flash_attn = has_flash_attn and not cfg.no_flash_attn
 
         if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
             return self.forward_paged(
@@ -968,7 +973,7 @@ def forward(self,
             )
 
         if self.is_tp:
-            if cache is not None:
+            if cache is not None and use_flash_attn:
                 return self.forward_tp(
                     hidden_states,
                     cache,
@@ -1002,7 +1007,6 @@ def forward(self,
                 **kwargs
             )
 
-        cfg = self.model.config
         constants = self.model.get_device_context(self.device_idx)
 
         batch_size, q_len, _ = hidden_states.shape
@@ -1193,7 +1197,10 @@ def forward_tp_old(
 
         assert self.q_handle is not None
         use_flash_attn = has_flash_attn and not cfg.no_flash_attn
-        assert use_flash_attn, "Tensor parallel inference requires flash-attn"
+        if not use_flash_attn:
+            assert has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping, \
+                "TP attention without flash-attn must use Torch SDPA with lower-right attention mask " \
+                "(use PyTorch 2.4.0+) and does not support logit softcapping."
 
         hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = cfg.head_dim)
 
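The new assert leans on a lower-right-aligned causal mask for Torch SDPA. A hedged sketch of the capability probe and of how such a mask is used (an assumption about the detection; the real has_lower_right_sdpa flag is set elsewhere in exllamav2):

import torch
import torch.nn.functional as F

try:
    # Available in recent PyTorch; the assert above recommends 2.4.0+.
    from torch.nn.attention.bias import causal_lower_right
    has_lower_right_sdpa = True
except ImportError:
    has_lower_right_sdpa = False

if has_lower_right_sdpa:
    q = torch.randn(1, 8, 4, 64)     # (batch, heads, q_len, head_dim): 4 new queries
    k = torch.randn(1, 8, 12, 64)    # 12 total keys: 8 cached + 4 new
    v = torch.randn(1, 8, 12, 64)
    # Lower-right alignment lets the short query block attend to the full cache,
    # which is what the non-flash TP path needs when a KV cache is present.
    mask = causal_lower_right(q.shape[2], k.shape[2])
    out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)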
@@ -1236,24 +1243,50 @@ def forward_tp_old(
             torch.cuda.set_stream(context.stream)
 
             if k_cache is not None:
-                attn_output = flash_attn_with_kvcache(
-                    q = q[idx],
-                    k = k[idx],
-                    v = v[idx],
-                    k_cache = k_cache[idx],
-                    v_cache = v_cache[idx],
-                    causal = True,
-                    softmax_scale = self.scaling,
-                    cache_seqlens = attn_params.past_len_tp[idx]
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_with_kvcache(
+                        q = q[idx],
+                        k = k[idx],
+                        v = v[idx],
+                        k_cache = k_cache[idx],
+                        v_cache = v_cache[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                        cache_seqlens = attn_params.past_len_tp[idx]
+                    )
+                else:
+                    cache_a = attn_params.past_len
+                    cache_b = attn_params.past_len + q_len
+                    k_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(k[idx])
+                    v_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(v[idx])
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k_cache[idx][:batch_size, :cache_b, :, :],
+                        v_cache[idx][:batch_size, :cache_b, :, :],
+                        attn_params,
+                        cfg
+                    )
             else:
-                attn_output = flash_attn_func(
-                    q[idx],
-                    k[idx],
-                    v[idx],
-                    causal = True,
-                    softmax_scale = self.scaling,
-                )
+                if use_flash_attn:
+                    attn_output = flash_attn_func(
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        causal = True,
+                        softmax_scale = self.scaling,
+                    )
+                else:
+                    attn_output = self._attn_torch(
+                        batch_size,
+                        q_len,
+                        q[idx],
+                        k[idx],
+                        v[idx],
+                        attn_params,
+                        cfg
+                    )
 
             attn_output = attn_output.view(batch_size * q_len, (b - a) * cfg.head_dim * cfg.num_key_value_groups)
             attn_outputs.append(attn_output)
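For reference, a self-contained sketch of what the non-flash branch above amounts to, with illustrative buffer names and shapes and assuming a PyTorch build that provides the lower-right SDPA bias: the new keys and values are copied into the preallocated cache at [past_len : past_len + q_len], and attention then runs over the whole cached prefix plus the new tokens.

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right  # assumes recent PyTorch

batch_size, q_len, past_len = 1, 4, 12
num_heads, head_dim, max_seq = 8, 64, 64

# Preallocated cache buffers, (batch, max_seq, heads, head_dim); illustrative only.
k_cache = torch.zeros(batch_size, max_seq, num_heads, head_dim)
v_cache = torch.zeros(batch_size, max_seq, num_heads, head_dim)

q = torch.randn(batch_size, q_len, num_heads, head_dim)
k = torch.randn(batch_size, q_len, num_heads, head_dim)
v = torch.randn(batch_size, q_len, num_heads, head_dim)

# Append the new K/V into the cache window for this step.
cache_a, cache_b = past_len, past_len + q_len
k_cache[:, cache_a:cache_b].copy_(k)
v_cache[:, cache_a:cache_b].copy_(v)

# Transpose to (batch, heads, seq, head_dim), as _attn_torch does, then run SDPA
# with a lower-right causal mask so the new queries can see the entire past.
attn_output = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k_cache[:, :cache_b].transpose(1, 2),
    v_cache[:, :cache_b].transpose(1, 2),
    attn_mask = causal_lower_right(q_len, cache_b),
)
attn_output = attn_output.transpose(1, 2).reshape((batch_size, q_len, num_heads * head_dim))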