@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn as nn
import types
@@ -1296,6 +1297,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format)
self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format)
self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format)
self.fsdpa_split_thld = int(os.getenv("VLLM_FUSEDSDPA_SPLIT_THLD", 8192))

def forward_qdq(
self,
@@ -1331,6 +1333,41 @@ def forward_qdq(
)
return results

def fp8_fsdpa_fwd(self,
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
softmax_mode,
):
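# Direct call into the raw HPU recompute-forward kernel so that, besides the
# attention output, the row-wise softmax statistics come back as well;
# forward_quant below consumes results[1] as the running max (m) and
# results[2] as the inverse normalizer (linv) when merging split results.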
results = torch.ops.hpu.fp8_sdpa_recomp_fwd(
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
True, # requires_backward
softmax_mode, # softmax_mode
self.scale_q, # d_scale_q
self.scale_k, # d_scale_k
self.scale_v, # d_scale_v
self.scale_amax, # q_scale_s
self.scale_output, # q_scale_o
self.descale_amax, # d_scale_s
False, # is_amax_s
False, # is_amax_o
None, # valid_seq_len
"right", # seq_padding_type
(-1, -1), # window_size
None, # sink
)
return results

def forward_quant(
self,
q,
@@ -1349,28 +1386,76 @@ def forward_quant(
qinput = self.quant_q(q).detach()
kinput = self.quant_k(k).detach()
vinput = self.quant_v(v).detach()
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out

q_len = q.shape[-2]  # number of new (uncached) query tokens
kv_len = kinput.size(-2)  # cached prefix plus new tokens

# for prefill with prefix caching
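# Long sequences are attended in two pieces: the cached prefix without a mask
# and the new prompt tokens with the provided mask; the partial results are
# then merged below with a flash-attention style rescale.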
if self.fsdpa_split_thld > 0 and q_len != 1 and q_len != kv_len and kv_len > self.fsdpa_split_thld:
assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching."
prefix_len = kv_len - q_len
from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
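# Reshape q/k/v (and the mask) into the layout the kernel expects for
# grouped-query attention; the partial outputs are reshaped back with
# gqa_output_reshape after each call.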
gqa = is_gqa(qinput, kinput)
if gqa:
qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask)

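# keys/values belonging to the cached prefix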
prefix_kinput = kinput[..., 0:prefix_len, :]
prefix_vinput = vinput[..., 0:prefix_len, :]

# keys/values for the new prompt tokens that are not covered by the prefix cache
text_kinput = kinput[..., prefix_len:, :]
text_vinput = vinput[..., prefix_len:, :]

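# restrict the mask to the new-token key columns; the prefix pass below runs without a mask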
mask = attn_mask[..., -q_len:]

# first pass: SDPA of the new queries over the cached prefix, without a mask
prefix_results = self.fp8_fsdpa_fwd(qinput, prefix_kinput, prefix_vinput, None, dropout_p, scale, False, sm_mode)
prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_results[:3])) if gqa else prefix_results[:3]

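# promote the partial results to fp32: softmax stats for the merge below,
# and the fp8 output dequantized via dequant_output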
prefix_m = prefix_m.to(torch.float32)
prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else prefix_linv.to(torch.float32)
prefix_out = self.dequant_output(prefix_out).to(torch.float32)

# second pass: SDPA over the new prompt tokens, with the sliced mask
text_results = self.fp8_fsdpa_fwd(qinput, text_kinput, text_vinput, mask, dropout_p, scale, False, sm_mode)
text_out, text_m, text_linv = (gqa_output_reshape(x) for x in (text_results[:3])) if gqa else text_results[:3]
text_m = text_m.to(torch.float32)
text_linv = text_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else text_linv.to(torch.float32)
text_out = self.dequant_output(text_out).to(torch.float32)

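# Flash-attention style merge of the two partial results: with l_i = 1 / linv_i
# the un-normalized row sums, rescale each by exp(m_i - max(m_prefix, m_text))
# and combine:
#   out = (l_prefix' * out_prefix + l_text' * out_text) / (l_prefix' + l_text')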
new_m = torch.maximum(prefix_m, text_m)
prefix_linv_rescaled = (1.0 / prefix_linv) * torch.exp(prefix_m - new_m)
text_linv_rescaled = (1.0 / text_linv) * torch.exp(text_m - new_m)
final_linv = 1.0 / (prefix_linv_rescaled + text_linv_rescaled)
final_out = (prefix_linv_rescaled * final_linv) * prefix_out + (text_linv_rescaled * final_linv) * text_out

return final_out.to(q.dtype)

else:
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out

def forward_measure(
self,