Split fp8_fused_sdpa into two phases #2346
base: aice/v122
|
cc @yangulei |
Co-authored-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Bob Zhu <bob.zhu@intel.com>
|
The output of the APC example code is OK. |
|
@czhu15 Thank you for raising this enhancement. We will double-check this change and make sure it does not break current usage. |
linoybu
left a comment
In the vLLM plugin, we are currently using FSDPA only during the prefill phase.
You can see this distinction here:
https://github.com/vllm-project/vllm-gaudi/blob/b8515d5fb8d5966768ad03e71bbbe1ad6661d7df/vllm_gaudi/attention/backends/hpu_attn.py#L262
This PR appears to be an attempt to separate decode and prefill operations to improve performance.
My question is: if we are not using FSDPA for decode, should we still expect any performance improvement?
Also, do you have a ticket that explains more about this issue?
|
Thank you for this contribution. @czhu15 |
|
Yes. This PR applies only during the prefill phase, and more specifically to the prefill phase when prefix caching is enabled. The current implementation passes a (big) atten_bias to the kernel, which can easily lead to OOM issues. |
Hi @xin3he This PR was targeted at aice/v122 or v3.6.post.oot for now. It’s okay to allow more flexibility in order to pursue ultimate performance. |
|
Hi @czhu15 please let me know once the local tests pass. I can help with the merge, or you’re welcome to do it yourself. |
Split fp8_fused_sdpa into two phases to decrease the TTFT.
The first phase calls the fused_sdpa kernel without a mask for the prefix-cached part.
The second phase calls the fused_sdpa kernel with a mask for the new prompt part.
Splitting fp8_fused_sdpa into two phases decreases memory consumption and also decreases the TTFT with the current Synapse fused_sdpa kernel.
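
For readers unfamiliar with the idea, the two phases can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the code in this PR: the thread does not show how the two partial results are combined, so the log-sum-exp merge below is an assumption (it is one standard way to join attention computed over disjoint key ranges), and every function name here is hypothetical. It ignores FP8, batching, heads, and the real fused_sdpa kernel, and it assumes a non-empty cached prefix.

```python
import numpy as np


def attention_with_lse(q, k, v, mask=None):
    """Plain softmax attention that also returns the per-row log-sum-exp.

    Hypothetical helper for illustration; the real kernel is fused_sdpa.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    out = (weights / denom) @ v
    lse = row_max + np.log(denom)
    return out, lse


def two_phase_attention(q, k_prefix, v_prefix, k_new, v_new):
    """Phase 1: cached prefix, no mask. Phase 2: new tokens, causal mask."""
    n_new = q.shape[0]
    # Phase 1: every new query may see the whole cached prefix, so no mask.
    out_p, lse_p = attention_with_lse(q, k_prefix, v_prefix)
    # Phase 2: only an (n_new x n_new) causal mask is needed.
    causal = np.tril(np.ones((n_new, n_new), dtype=bool))
    out_n, lse_n = attention_with_lse(q, k_new, v_new, causal)
    # Assumed merge step: reweight each partial output by its softmax mass.
    lse = np.logaddexp(lse_p, lse_n)
    return np.exp(lse_p - lse) * out_p + np.exp(lse_n - lse) * out_n


def single_pass_attention(q, k_prefix, v_prefix, k_new, v_new):
    """Reference: one call with a full (n_new x (n_prefix + n_new)) mask."""
    n_new, n_prefix = q.shape[0], k_prefix.shape[0]
    mask = np.concatenate(
        [np.ones((n_new, n_prefix), dtype=bool),
         np.tril(np.ones((n_new, n_new), dtype=bool))],
        axis=1,
    )
    k = np.concatenate([k_prefix, k_new])
    v = np.concatenate([v_prefix, v_new])
    out, _ = attention_with_lse(q, k, v, mask)
    return out
```

In this sketch the mask shrinks from `n_new * (n_prefix + n_new)` entries to `n_new * n_new`, which mirrors the memory argument made above for long cached prefixes.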