
[Feature Request] Fuse kda_gate_fwd and chunk_local_cumsum #670

@zhiyuan1i

Description

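Feature Request

Fuse `kda_gate_fwd` and `chunk_local_cumsum` into a single kernel. Currently, when `use_gate_in_kernel` is set, `g` is first transformed by `kda_gate_fwd` (using `A_log` and `dt_bias`), and the result is then passed to a separate `chunk_local_cumsum` call. Otherwise, `chunk_local_cumsum` is applied to `g` directly:
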
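```python
if use_gate_in_kernel:
    g_org = g
    # Kernel 1: apply the KDA gate (uses A_log and dt_bias) and write the gated g out.
    g = kda_gate_fwd(
        g=g_org,
        A_log=A_log,
        dt_bias=dt_bias,
    )
    # Kernel 2: read the gated g back and take its chunk-local cumulative sum.
    # These two launches are the candidates for fusion.
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
else:
    # No in-kernel gate: only the chunk-local cumulative sum is needed.
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
```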
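For reference, here is a NumPy sketch of what a fused operation would need to compute. It assumes the KDA gate has the form `-exp(A_log) * softplus(g + dt_bias)`, with `g` of shape `[T, H, K]`, `A_log` of shape `[H]` and `dt_bias` of shape `[H, K]`. These are assumptions for illustration; the exact formula, shapes and dtype casts should be taken from `kda_gate_fwd` itself. The helper names (`kda_gate_ref`, `chunk_local_cumsum_ref`, `fused_gate_cumsum_ref`) are hypothetical, and only a single fixed-length sequence is handled (the `cu_seqlens` / `chunk_indices` variable-length path is omitted).

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def kda_gate_ref(g, A_log, dt_bias):
    # ASSUMPTION: gate of the form -exp(A_log) * softplus(g + dt_bias).
    # g: [T, H, K], A_log: [H], dt_bias: [H, K] (hypothetical shapes).
    return -np.exp(A_log)[None, :, None] * softplus(g + dt_bias[None])


def chunk_local_cumsum_ref(g, chunk_size):
    # Cumulative sum along time that restarts at every chunk boundary.
    out = np.empty_like(g)
    for start in range(0, g.shape[0], chunk_size):
        out[start:start + chunk_size] = np.cumsum(g[start:start + chunk_size], axis=0)
    return out


def fused_gate_cumsum_ref(g, A_log, dt_bias, chunk_size):
    # Hypothetical fused op: the gate is applied on the fly inside the running
    # sum, so the gated g is never stored as a separate tensor.
    decay = -np.exp(A_log)[:, None]              # [H, 1]
    out = np.empty_like(g)
    acc = np.zeros(g.shape[1:], dtype=g.dtype)   # running sum, [H, K]
    for t in range(g.shape[0]):
        if t % chunk_size == 0:
            acc[...] = 0.0                       # reset at each chunk start
        acc += decay * softplus(g[t] + dt_bias)
        out[t] = acc
    return out
```

The unfused path corresponds to `chunk_local_cumsum_ref(kda_gate_ref(g, A_log, dt_bias), chunk_size)`, and the fused sketch should match it. In a real GPU kernel, the per-chunk running sum would be computed as a scan over the chunk (for example, one program per chunk) rather than as a Python loop. The gain from fusion would come from not writing the gated `g` to memory and reading it back.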

Motivation

N/A

Your Contribution

N/A

Metadata

Labels

enhancement (New feature or request)

Development

No branches or pull requests

Issue actions