diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 05770f7b171..d50ea583807 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 import torch.nn as nn
 import types
@@ -1296,6 +1297,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format)
         self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format)
         self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format)
+        self.fsdpa_split_thld = int(os.getenv("VLLM_FUSEDSDPA_SPLIT_THLD", 8192))
 
     def forward_qdq(
         self,
@@ -1331,6 +1333,41 @@ def forward_qdq(
         )
         return results
 
+    def fp8_fsdpa_fwd(self,
+                      q,
+                      k,
+                      v,
+                      attn_mask,
+                      dropout_p,
+                      scale,
+                      is_causal,
+                      softmax_mode,
+                      ):
+        results = torch.ops.hpu.fp8_sdpa_recomp_fwd(
+            q,
+            k,
+            v,
+            attn_mask,
+            dropout_p,
+            scale,
+            is_causal,
+            True,  # requires_backward
+            softmax_mode,  # softmax_mode
+            self.scale_q,  # d_scale_q
+            self.scale_k,  # d_scale_k
+            self.scale_v,  # d_scale_v
+            self.scale_amax,  # q_scale_s
+            self.scale_output,  # q_scale_o
+            self.descale_amax,  # d_scale_s
+            False,  # is_amax_s
+            False,  # is_amax_o
+            None,  # valid_seq_len
+            "right",  # seq_padding_type
+            (-1, -1),  # window_size
+            None,  # sink
+        )
+        return results
+
     def forward_quant(
         self,
         q,
@@ -1349,28 +1386,76 @@ def forward_quant(
         qinput = self.quant_q(q).detach()
         kinput = self.quant_k(k).detach()
         vinput = self.quant_v(v).detach()
-        results = self.fp8_fused_sdpa(
-            qinput,
-            kinput,
-            vinput,
-            attn_mask=attn_mask,
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            scale=scale,
-            softmax_mode=sm_mode,
-            d_scale_q=self.scale_q,
-            d_scale_k=self.scale_k,
-            d_scale_v=self.scale_v,
-            q_scale_s=self.scale_amax,
-            q_scale_o=self.scale_output,
-            d_scale_s=self.descale_amax,
-            is_amax_s=False,
-            valid_seq_len=valid_seq_len,
-            seq_padding_type=seq_padding_type,
-        )
-        output = results[0]
-        d_out = self.dequant_output(output)
-        return d_out
+
+        q_len = q.shape[-2]
+        kv_len = kinput.size(-2)
+
+        # for prefill with prefix caching
+        if self.fsdpa_split_thld > 0 and q_len != 1 and q_len != kv_len and kv_len > self.fsdpa_split_thld:
+            assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching."
+            prefix_len = kv_len - q_len
+            from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
+            gqa = is_gqa(qinput, kinput)
+            if gqa:
+                qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask)
+
+            prefix_kinput = kinput[..., 0:prefix_len, :]
+            prefix_vinput = vinput[..., 0:prefix_len, :]
+
+            # the new prompt part not in prefix caching
+            text_kinput = kinput[..., prefix_len:, :]
+            text_vinput = vinput[..., prefix_len:, :]
+
+            mask = attn_mask[..., -q_len:]
+
+            # calculate the first prefix sdpa w/o mask
+            prefix_results = self.fp8_fsdpa_fwd(qinput, prefix_kinput, prefix_vinput, None, dropout_p, scale, False, sm_mode)
+            prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_results[:3])) if gqa else prefix_results[:3]
+
+            prefix_m = prefix_m.to(torch.float32)
+            prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else prefix_linv.to(torch.float32)
+            prefix_out = self.dequant_output(prefix_out).to(torch.float32)
+
+            # calculate the second new prompt part with mask
+            text_results = self.fp8_fsdpa_fwd(qinput, text_kinput, text_vinput, mask, dropout_p, scale, False, sm_mode)
+            text_out, text_m, text_linv = (gqa_output_reshape(x) for x in (text_results[:3])) if gqa else text_results[:3]
+            text_m = text_m.to(torch.float32)
+            text_linv = text_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else text_linv.to(torch.float32)
+            text_out = self.dequant_output(text_out).to(torch.float32)
+
+            new_m = torch.maximum(prefix_m, text_m)
+            prefix_linv_rescaled = (1.0 / prefix_linv) * torch.exp(prefix_m - new_m)
+            text_linv_rescaled = (1.0 / text_linv) * torch.exp(text_m - new_m)
+            final_linv = 1.0 / (prefix_linv_rescaled + text_linv_rescaled)
+            final_out = (prefix_linv_rescaled * final_linv) * prefix_out + (
+                text_linv_rescaled * final_linv) * text_out
+            # prefix_m = new_m
+
+            return final_out.to(q.dtype)
+
+        else:
+            results = self.fp8_fused_sdpa(
+                qinput,
+                kinput,
+                vinput,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                softmax_mode=sm_mode,
+                d_scale_q=self.scale_q,
+                d_scale_k=self.scale_k,
+                d_scale_v=self.scale_v,
+                q_scale_s=self.scale_amax,
+                q_scale_o=self.scale_output,
+                d_scale_s=self.descale_amax,
+                is_amax_s=False,
+                valid_seq_len=valid_seq_len,
+                seq_padding_type=seq_padding_type,
+            )
+            output = results[0]
+            d_out = self.dequant_output(output)
+            return d_out
 
     def forward_measure(
         self,
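Note (not part of the patch): the split path above is a two-chunk, flash-attention-style merge. Each `fp8_fsdpa_fwd` call returns a partial output that is already normalized within its own K/V chunk, plus the per-row max `m` and the inverse softmax sum `linv`; the two partials are then recombined with max/exp rescaling. Below is a minimal plain-PyTorch sketch of that merge, with masks, GQA reshapes, and the fp8 scale handling omitted; the helper names (`_partial_attention`, `merged_chunked_attention`) are illustrative and not from the library.

```python
import torch

def _partial_attention(q, k, v, scale):
    # Attention over one K/V chunk. Returns the chunk-normalized output plus
    # the per-row max (m) and inverse softmax sum (linv), mirroring the
    # (out, m, linv) triple consumed by the merge in forward_quant.
    s = torch.matmul(q, k.transpose(-2, -1)) * scale
    m = s.amax(dim=-1, keepdim=True)
    p = torch.exp(s - m)
    l = p.sum(dim=-1, keepdim=True)
    return torch.matmul(p / l, v), m, 1.0 / l

def merged_chunked_attention(q, k, v, scale, split):
    # Attend to k/v[:split] and k/v[split:] separately, then merge the two
    # normalized partial outputs with the same max/exp rescaling as the patch.
    o1, m1, linv1 = _partial_attention(q, k[..., :split, :], v[..., :split, :], scale)
    o2, m2, linv2 = _partial_attention(q, k[..., split:, :], v[..., split:, :], scale)
    new_m = torch.maximum(m1, m2)
    w1 = (1.0 / linv1) * torch.exp(m1 - new_m)  # = l1 * exp(m1 - new_m)
    w2 = (1.0 / linv2) * torch.exp(m2 - new_m)  # = l2 * exp(m2 - new_m)
    inv_total = 1.0 / (w1 + w2)
    return (w1 * inv_total) * o1 + (w2 * inv_total) * o2

if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 4, 16, 64)
    k = torch.randn(1, 4, 48, 64)
    v = torch.randn(1, 4, 48, 64)
    scale = 64 ** -0.5
    ref, _, _ = _partial_attention(q, k, v, scale)  # unsplit reference
    out = merged_chunked_attention(q, k, v, scale, split=32)
    print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```

Running the sketch prints `True`, confirming that merging the two normalized chunks reproduces unsplit attention up to floating-point error, which is the property the prefix/text split in `forward_quant` relies on.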