diff --git a/fla/__init__.py b/fla/__init__.py index 839221b4d..82d3ee49b 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -26,6 +26,7 @@ RodimusAttention, RWKV6Attention, RWKV7Attention, + StickBreakingAttention, ) from fla.models import ( ABCForCausalLM, @@ -74,6 +75,8 @@ RWKV6Model, RWKV7ForCausalLM, RWKV7Model, + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel, TransformerForCausalLM, TransformerModel, ) @@ -105,6 +108,7 @@ 'RodimusAttention', 'RodimusForCausalLM', 'RodimusModel', 'RWKV6Attention', 'RWKV6ForCausalLM', 'RWKV6Model', 'RWKV7Attention', 'RWKV7ForCausalLM', 'RWKV7Model', + 'StickBreakingAttention', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel', ] __version__ = '0.4.0' diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 5a23eac1d..7382d6906 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -30,6 +30,7 @@ from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention from .rwkv6 import RWKV6Attention from .rwkv7 import RWKV7Attention +from .stickbreaking_attn import StickBreakingAttention __all__ = [ 'ABCAttention', @@ -61,6 +62,7 @@ 'RodimusAttention', 'RWKV6Attention', 'RWKV7Attention', + 'StickBreakingAttention', 'SlidingWindowSharedKeyAttention', 'DeltaFormerAttention', ] diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py new file mode 100644 index 000000000..fd81bbee0 --- /dev/null +++ b/fla/layers/stickbreaking_attn.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RMSNorm +from fla.ops.stickbreaking_attn import parallel_stickbreaking_attn + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +logger = logging.get_logger(__name__) + + +class StickBreakingAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: int | None = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: int | None = None, + max_position_embeddings: int | None = None, + layer_idx: int | None = None, + ): + super().__init__() + + if parallel_stickbreaking_attn is None: + raise ImportError( + "StickBreakingAttention kernels are not available. 
Ensure Triton is installed and ops are importable.", + ) + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + if use_cache: + warnings.warn( + "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.") + use_cache = False + + batch_size, q_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_dim) + + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + cu_seqlens = kwargs.get('cu_seqlens') + o, _rem = parallel_stickbreaking_attn( + q=q, + k=k, + v=v, + cu_seqlens=cu_seqlens, + ) + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 03fda347d..a95d7d1d2 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -31,6 +31,11 @@ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel +from fla.models.stickbreaking_attn import ( + StickBreakingAttentionConfig, + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel, +) from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel __all__ = [ @@ -63,4 +68,5 @@ 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model', 'SambaConfig', 'SambaForCausalLM', 'SambaModel', 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', + 'StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel', ] diff --git a/fla/models/stickbreaking_attn/__init__.py b/fla/models/stickbreaking_attn/__init__.py new file mode 100644 index 000000000..f9fdf7297 --- /dev/null +++ b/fla/models/stickbreaking_attn/__init__.py @@ -0,0 +1,15 @@ + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig +from fla.models.stickbreaking_attn.modeling_stickbreaking_attn import ( + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel, +) + +AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) +AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True) +AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True) + + +__all__ = ['StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel'] diff --git a/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py new file mode 100644 index 000000000..439cc38f2 --- /dev/null +++ b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py @@ -0,0 +1,82 @@ +import warnings + +from transformers.configuration_utils import PretrainedConfig + + +class StickBreakingAttentionConfig(PretrainedConfig): + + model_type = 'stickbreaking_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int | None = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: int | None = None, + max_position_embeddings: int = 2048, + hidden_ratio: int | None = 4, + intermediate_size: int | None = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: bool | None = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int | None = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + fuse_linear_cross_entropy: bool = False, + use_l2warp: bool = False, + vocab_size: int = 32000, + 
**kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.qkv_bias = qkv_bias
+        self.qk_norm = qk_norm
+        self.window_size = window_size
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+
+        self.initializer_range = initializer_range
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.fuse_linear_cross_entropy = fuse_linear_cross_entropy
+        self.use_l2warp = use_l2warp
+        self.vocab_size = vocab_size
+
+        if fuse_cross_entropy and fuse_linear_cross_entropy:
+            raise ValueError(
+                "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.",
+            )
+        if fuse_linear_cross_entropy:
+            warnings.warn(
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
+                "at the potential cost of reduced precision. "
+                "If you observe issues like loss divergence, consider disabling this setting.",
+            )
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
diff --git a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py
new file mode 100644
index 000000000..a72e1ae0e
--- /dev/null
+++ b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Any
+
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.utils.deprecation import deprecate_kwarg
+
+from fla.layers.stickbreaking_attn import StickBreakingAttention
+from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig
+from fla.models.utils import Cache, FLAGenerationMixin
+from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
+from fla.modules import GatedMLP as SBAttnMLP
+from fla.modules.l2warp import l2_warp
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+
+try:
+    from transformers.modeling_layers import GradientCheckpointingLayer
+except ImportError:
+    from fla.models.modeling_layers import GradientCheckpointingLayer
+
+logger = logging.get_logger(__name__)
+
+
+class StickBreakingAttentionBlock(GradientCheckpointingLayer):
+
+    def __init__(self, config: StickBreakingAttentionConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.attn = StickBreakingAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_heads,
+            num_kv_heads=config.num_kv_heads,
+            qkv_bias=config.qkv_bias,
+            qk_norm=config.qk_norm,
+            window_size=config.window_size,
+            max_position_embeddings=config.max_position_embeddings,
+            layer_idx=layer_idx,
+        )
+
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp = SBAttnMLP(
hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_values: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + **kwargs: Unpack[Any], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class StickBreakingAttentionPreTrainedModel(PreTrainedModel): + + config_class = StickBreakingAttentionConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['StickBreakingAttentionBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class StickBreakingAttentionModel(StickBreakingAttentionPreTrainedModel): + + def __init__( + self, + config: StickBreakingAttentionConfig, + ) -> StickBreakingAttentionModel: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + layers = [] + for layer_idx in range(config.num_hidden_layers): + layers.append(StickBreakingAttentionBlock(config, layer_idx)) + self.layers = nn.ModuleList(layers) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: 
torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs: Unpack[Any], + ) -> tuple | CausalLMOutputWithPast: + if output_attentions: + warnings.warn( + "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.", + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + hidden_states = inputs_embeds + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns, + ) + + +class StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = StickBreakingAttentionModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + 
output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | None = 0, + **kwargs: Unpack[Any], + ) -> tuple | CausalLMOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + hidden_states = outputs[0] + + logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) + + loss = None + if labels is not None: + if getattr(self, 'criterion', None) is None: + if self.config.fuse_linear_cross_entropy: + criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if self.config.fuse_linear_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + loss = l2_warp(loss, logits) if self.config.use_l2warp else loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/ops/__init__.py b/fla/ops/__init__.py index c11ec7832..a773016f4 100644 --- a/fla/ops/__init__.py +++ b/fla/ops/__init__.py @@ -7,10 +7,10 @@ from .forgetting_attn import parallel_forgetting_attn from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule from .generalized_delta_rule import ( - chunk_dplr_delta_rule, - chunk_iplr_delta_rule, - fused_recurrent_dplr_delta_rule, - fused_recurrent_iplr_delta_rule, + chunk_dplr_delta_rule, + chunk_iplr_delta_rule, + fused_recurrent_dplr_delta_rule, + fused_recurrent_iplr_delta_rule, ) from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla from .gsa import chunk_gsa, fused_recurrent_gsa @@ -26,6 +26,7 @@ from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 from .simple_gla import chunk_simple_gla, fused_chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla +from .stickbreaking_attn.parallel import parallel_stickbreaking_attn __all__ = [ 'chunk_abc', @@ -51,4 +52,5 @@ 'chunk_rwkv6', 'fused_recurrent_rwkv6', 'chunk_rwkv7', 'fused_recurrent_rwkv7', 'chunk_simple_gla', 'fused_chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', + 'parallel_stickbreaking_attn', ] diff --git a/fla/ops/stickbreaking_attn/__init__.py b/fla/ops/stickbreaking_attn/__init__.py new file mode 100644 index 000000000..16472bf23 --- /dev/null +++ b/fla/ops/stickbreaking_attn/__init__.py @@ -0,0 
+1,7 @@ +from .naive import naive_stickbreaking_attn +from .parallel import parallel_stickbreaking_attn + +__all__ = [ + 'parallel_stickbreaking_attn', + 'naive_stickbreaking_attn', +] diff --git a/fla/ops/stickbreaking_attn/naive.py b/fla/ops/stickbreaking_attn/naive.py new file mode 100644 index 000000000..4872ecc67 --- /dev/null +++ b/fla/ops/stickbreaking_attn/naive.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import torch.nn.functional as F + + +def naive_stickbreaking_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Naive stick-breaking attention reference implementation. + + Args: + q, k, v: [B, T, H, D] + scale: inverse temperature (1/sqrt(D)) + Returns: + o: [B, T, H, D] + rem: [B, T, H] (1 - sum of attention up to t) + """ + _, T, _, D = q.shape + orig_dtype = q.dtype + if scale is None: + scale = D ** -0.5 + + logits = torch.einsum('bthd,bshd->bhts', q, k) * scale + logits = logits.float() + + mask = torch.ones(T, T, device=q.device).triu(0).bool() # exclude diagonal + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, T, T] + + log_z = F.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype) + log_beta = F.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype) + + cum_weight = torch.ones(T, T, device=q.device).tril(-1) + + re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight.to(log_beta)) + log_att = log_z + re_cum_log_beta + att = log_att.exp() + o = torch.einsum('bhts,bshd->bthd', att, v) + rem = 1 - att.sum(dim=-1).transpose(1, 2) + + return o.to(orig_dtype), rem.to(orig_dtype) + + +__all__ = [ + 'naive_stickbreaking_attn', +] diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py new file mode 100644 index 000000000..8ac9c7bd0 --- /dev/null +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -0,0 +1,713 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from fla.ops.stickbreaking_attn.softplus import softplus +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + +ALLOW_TF32 = True + + +@triton.jit +def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): + if NO_D_MASK: + if NO_N_MASK: + k = tl.load(K_blk_ptrs) + v = tl.load(V_blk_ptrs) + else: + k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) + v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + k = tl.load(K_blk_ptrs, mask=mask) + v = tl.load(V_blk_ptrs, mask=mask) + return k, v + + +@triton.jit +def compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band: tl.constexpr, + ALLOW_TF32: tl.constexpr, + backward: tl.constexpr, + use_cumsum: tl.constexpr = False, +): + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale + log_om_beta = -softplus(qk) + + if on_band: + block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] + log_om_beta = tl.where(block_mask, log_om_beta, 0.0) + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + p = tl.where(block_mask, p, 0.0) + else: + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = 
qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + if not backward: + neg_log_acc += tl.sum(log_om_beta, axis=1) + return p, log_om_beta, neg_log_acc + + +@triton.jit +def stickbreaking_attn_fwd_one_row_kernel( + seq_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + O_head_seq_ptr, + R_head_seq_ptr, + A_head_seq_ptr, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + return_attention: tl.constexpr = False, +): + block_start_offset = BT * seq_block_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + N_blk_idxs_start = block_start_offset + BT + N_blk_idxs = N_blk_idxs_start + N_range + + Q_blk_ptrs = Q_head_seq_ptr + ( + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + K_blk_ptrs = K_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + V_blk_ptrs = V_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + O_blk_ptrs = O_head_seq_ptr + ( + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + R_blk_ptrs = R_head_seq_ptr + H * M_blk_idxs + A_blk_ptrs = A_head_seq_ptr + H * M_blk_idxs + + if NO_D_MASK: + if NO_M_MASK: + q = tl.load(Q_blk_ptrs) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None], other=0.0) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None] & D_mask[None, :], other=0.0) + + iters = N_blk_idxs_start // BS + neg_log_acc = tl.zeros([BT], dtype=acc_dtype) + acc = tl.zeros([BT, BLOCK_D], dtype=acc_dtype) + + for i in range(iters): + N_blk_idxs -= BS + N_blk_idxs_start -= BS + K_blk_ptrs -= BS * (H * head_size) + V_blk_ptrs -= BS * (H * head_size) + + N_mask = N_blk_idxs < seq_length + k, v = load_kv( + K_blk_ptrs, + V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=N_blk_idxs_start + BS - 1 < seq_length, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + on_band = i < BT // BS + p, _log_om_beta, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + backward=False, + use_cumsum=False, + ) + acc = tl.dot(p.to(v.dtype), v, acc, allow_tf32=ALLOW_TF32) + + if NO_M_MASK: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc)) + tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty)) + else: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc), mask=M_mask) + tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty), mask=M_mask) + if NO_D_MASK: + tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), mask=M_mask[:, None]) + else: + tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), mask=M_mask[:, None] & D_mask[None, :]) + + +@triton.jit +def stickbreaking_attn_bwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + DR_head_seq_ptr, + A_head_seq_ptr, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + DQ_head_seq_ptr, + DK_head_seq_ptr, + DV_head_seq_ptr, + scale, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + 
NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, +): + block_start_offset = BT * seq_prog_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + + N_blk_idxs_start = 0 + N_blk_idxs = N_blk_idxs_start + N_range + + DO_blk_ptrs = DO_head_seq_ptr + ( + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + K_blk_ptrs = K_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + Q_blk_ptrs = Q_head_seq_ptr + ( + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + V_blk_ptrs = V_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + A_blk_ptrs = A_head_seq_ptr + H * M_blk_idxs + DQ_blk_ptrs = DQ_head_seq_ptr + ( + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DK_blk_ptrs = DK_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DV_blk_ptrs = DV_head_seq_ptr + ( + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DR_blk_ptrs = DR_head_seq_ptr + H * M_blk_idxs + + if NO_D_MASK: + if NO_N_MASK: + q = tl.load(Q_blk_ptrs) + do = tl.load(DO_blk_ptrs) + dr = tl.load(DR_blk_ptrs) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None]) + do = tl.load(DO_blk_ptrs, mask=M_mask[:, None]) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + MD_mask = M_mask[:, None] & D_mask[None, :] + q = tl.load(Q_blk_ptrs, mask=MD_mask) + do = tl.load(DO_blk_ptrs, mask=MD_mask) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + + neg_log_acc = neg_log_acc.to(dtype=acc_dtype) + grad_prev_acc = tl.zeros((BT,), dtype=acc_dtype) + dq = tl.zeros((BT, BLOCK_D), dtype=acc_dtype) + + fwd_cm = tl.trans(cm) + iters = (block_start_offset + BT) // BS + for i in range(iters): + on_band = (iters - i - 1) < BT // BS + N_mask = N_blk_idxs < seq_length + local_no_n_mask = (N_blk_idxs_start + BS - 1) < seq_length + k, v = load_kv( + K_blk_ptrs, + V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=local_no_n_mask, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + p, log_om_beta, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + backward=True, + ) + + if not NO_M_MASK: + neg_log_acc = tl.where(M_mask, neg_log_acc, 0.0) + + att_dA = p * (tl.dot(do, tl.trans(v), allow_tf32=ALLOW_TF32) - dr[:, None]) + cumul_att_dA = tl.dot(att_dA.to(cm.dtype), fwd_cm, allow_tf32=ALLOW_TF32) + grad_prev_acc[:, None] + grad_prev_acc += tl.sum(att_dA, axis=1) + beta = 1 - tl.exp2(log_om_beta) + dqk = att_dA - beta * cumul_att_dA + + dq = tl.dot(dqk.to(k.dtype), k, acc=dq, allow_tf32=ALLOW_TF32) + block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, allow_tf32=ALLOW_TF32) * scale + block_dv = tl.dot(tl.trans(p), do.to(p.dtype), allow_tf32=ALLOW_TF32) + + if NO_D_MASK: + tl.store(DK_blk_ptrs, block_dk, mask=N_mask[:, None]) + tl.store(DV_blk_ptrs, block_dv, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + tl.store(DK_blk_ptrs, block_dk, mask=mask) + tl.store(DV_blk_ptrs, block_dv, mask=mask) + + N_blk_idxs += BS + N_blk_idxs_start += BS + K_blk_ptrs += BS * (H * head_size) + V_blk_ptrs += BS * (H * head_size) + DK_blk_ptrs += BS * (H * head_size) + DV_blk_ptrs += BS * (H * head_size) + + dq = (scale * dq).to(DQ_head_seq_ptr.type.element_ty) + + if NO_D_MASK: + tl.store(DQ_blk_ptrs, 
dq, mask=M_mask[:, None]) + else: + tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None] & D_mask[None, :]) + + +@triton.autotune( + configs=[ + triton.Config({}, num_stages=s, num_warps=w) + for s in [4] + for w in [4] + ], + key=["T", "head_size"] +) +@triton.jit +def parallel_stickbreaking_attn_fwd_kernel( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + R_ptr, + A_ptr, + CU_ptr, + CI_ptr, + scale: tl.constexpr, + B, + T, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BT % BS == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + tl.num_programs(2) + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = tl.full([], 0, dtype=tl.int32) + seq_block_id = prog_id + seq_length = T + RCP_LN2: tl.constexpr = 1.4426950216 + + qk_scale = RCP_LN2 * scale + M_range = tl.arange(0, BT) + N_range = tl.arange(0, BS) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + head_id = head_pid + seq_prog_id = seq_block_id + batch_offset = batch_id * T + Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * H + head_id) * head_size + K_head_seq_ptr = K_ptr + ((batch_offset + bos) * H + head_id) * head_size + V_head_seq_ptr = V_ptr + ((batch_offset + bos) * H + head_id) * head_size + O_head_seq_ptr = O_ptr + ((batch_offset + bos) * H + head_id) * head_size + R_head_seq_ptr = R_ptr + ((batch_offset + bos) * H + head_id) + A_head_seq_ptr = A_ptr + ((batch_offset + bos) * H + head_id) + + stickbreaking_attn_fwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + O_head_seq_ptr, + R_head_seq_ptr, + A_head_seq_ptr, + head_size, + H, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BT, + BS, + no_grad, + acc_dtype, + False, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_stages=s, num_warps=w) + for s in [8] + for w in [4] + ], + key=["T", "head_size"] +) +@triton.jit() +def parallel_stickbreaking_attn_bwd_kernel( + DO_ptr, + DR_ptr, + A_ptr, + Q_ptr, + K_ptr, + V_ptr, + DQ_ptr, + DK_ptr, + DV_ptr, + CU_ptr, + CI_ptr, + scale, + B, + T, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BT % BS == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + RCP_LN2: tl.constexpr = 1.4426950216 + + qk_scale = RCP_LN2 * scale + M_range = tl.arange(0, BT) + N_range = tl.arange(0, BS) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = 
tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = 0 + seq_block_id = prog_id + seq_length = T + + head_id = head_pid + seq_prog_id = seq_block_id + + batch_id_i64 = batch_id.to(tl.int64) + head_id_i64 = head_id.to(tl.int64) + seq_prog_id_i64 = seq_prog_id.to(tl.int64) + bos_i64 = bos.to(tl.int64) + + batch_offset = batch_id_i64 * T + head_offset = (batch_offset + bos_i64) * H + head_id_i64 + block_offset = seq_prog_id_i64 * B * T * H + + DO_head_seq_ptr = DO_ptr + head_offset * head_size + DR_head_seq_ptr = DR_ptr + head_offset + A_head_seq_ptr = A_ptr + head_offset + Q_head_seq_ptr = Q_ptr + head_offset * head_size + K_head_seq_ptr = K_ptr + head_offset * head_size + V_head_seq_ptr = V_ptr + head_offset * head_size + DQ_head_seq_ptr = DQ_ptr + head_offset * head_size + DK_head_seq_ptr = DK_ptr + (block_offset + head_offset) * head_size + DV_head_seq_ptr = DV_ptr + (block_offset + head_offset) * head_size + + stickbreaking_attn_bwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + DR_head_seq_ptr, + A_head_seq_ptr, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + DQ_head_seq_ptr, + DK_head_seq_ptr, + DV_head_seq_ptr, + scale, + head_size, + H, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BT, + BS, + acc_dtype, + ) + + +def parallel_stickbreaking_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward Triton kernel and return (o, rem, neg_log_acc). + + q, k, v: [B, T, H, D] + Returns: o [B, T, H, D], rem [B, T, H], neg_log_acc [B, T, H] + """ + B, T, H, D = q.size() + o = torch.empty_like(q) + rem = torch.zeros_like(q[:, :, :, 0], device=q.device) + neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) + + BT = 64 + BS = 64 + if cu_seqlens is None: + NT = triton.cdiv(T, BT) + grid = (B, H, NT) + CI = None + else: + CI = prepare_chunk_indices(cu_seqlens, BT) + NT = int(CI.shape[0]) + grid = (1, H, NT) + BLOCK_D = triton.next_power_of_2(D) + + NO_M_MASK = (T % BT) == 0 + NO_N_MASK = (T % BS) == 0 + if cu_seqlens is not None: + NO_M_MASK = False + NO_N_MASK = False + + parallel_stickbreaking_attn_fwd_kernel[grid]( + q, + k, + v, + o, + rem, + neg_log_acc, + CU_ptr=cu_seqlens if cu_seqlens is not None else q, + CI_ptr=CI if CI is not None else q, + scale=scale, + B=B, + T=T, + head_size=D, + H=H, + BLOCK_D=BLOCK_D, + NO_D_MASK=D == BLOCK_D, + NO_M_MASK=NO_M_MASK, + NO_N_MASK=NO_N_MASK, + ALLOW_TF32=ALLOW_TF32, + BT=BT, + BS=BS, + no_grad=False, + IS_VARLEN=cu_seqlens is not None, + ) + + return o, rem, neg_log_acc + + +def parallel_stickbreaking_attn_bwd( + do: torch.Tensor, + dr: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + neg_log_acc: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, D = q.size() + BT = 64 + BS = 64 + if cu_seqlens is None: + M_count = triton.cdiv(T, BT) + grid = (B, H, M_count) + CI = None + else: + CI = prepare_chunk_indices(cu_seqlens, BT) + M_count = int(CI.shape[0]) + grid = (1, H, M_count) + dq = torch.zeros_like(q) + dk = torch.zeros((M_count, B, T, H, D), dtype=k.dtype, device=k.device) + dv = torch.zeros((M_count, B, T, H, D), dtype=v.dtype, 
device=v.device)
+
+    BLOCK_D = triton.next_power_of_2(D)
+
+    NO_M_MASK = (T % BT) == 0
+    NO_N_MASK = (T % BS) == 0
+    if cu_seqlens is not None:
+        NO_M_MASK = False
+        NO_N_MASK = False
+
+    parallel_stickbreaking_attn_bwd_kernel[grid](
+        do,
+        dr,
+        neg_log_acc,
+        q,
+        k,
+        v,
+        dq,
+        dk,
+        dv,
+        CU_ptr=cu_seqlens if cu_seqlens is not None else q,
+        CI_ptr=CI if CI is not None else q,
+        scale=scale,
+        B=B,
+        T=T,
+        head_size=D,
+        H=H,
+        BT=BT,
+        BS=BS,
+        BLOCK_D=BLOCK_D,
+        NO_D_MASK=D == BLOCK_D,
+        NO_M_MASK=NO_M_MASK,
+        NO_N_MASK=NO_N_MASK,
+        ALLOW_TF32=ALLOW_TF32,
+        acc_dtype=tl.float32,
+        IS_VARLEN=cu_seqlens is not None,
+    )
+
+    dk_final = dk.sum(0)
+    dv_final = dv.sum(0)
+
+    return dq.to(q.dtype), dk_final, dv_final
+
+
+class StickBreakingAttentionFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        scale: float,
+        cu_seqlens: torch.LongTensor | None = None,
+    ):
+        o, rem, neg_log_acc = parallel_stickbreaking_attn_fwd(q, k, v, scale, cu_seqlens)
+        ctx.save_for_backward(q, k, v, neg_log_acc)
+        ctx.scale = scale
+        ctx.cu_seqlens = cu_seqlens
+        return o, rem
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do: torch.Tensor, drem: torch.Tensor):
+        q, k, v, neg_log_acc = ctx.saved_tensors
+        dq, dk, dv = parallel_stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.cu_seqlens)
+        return dq, dk, dv, None, None
+
+
+def parallel_stickbreaking_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float | None = None,
+    cu_seqlens: torch.LongTensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    if cu_seqlens is not None:
+        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
+    return StickBreakingAttentionFunction.apply(q, k, v, scale, cu_seqlens)
+
+
+__all__ = [
+    'parallel_stickbreaking_attn',
+]
diff --git a/fla/ops/stickbreaking_attn/softplus.py b/fla/ops/stickbreaking_attn/softplus.py
new file mode 100644
index 000000000..20d306718
--- /dev/null
+++ b/fla/ops/stickbreaking_attn/softplus.py
@@ -0,0 +1,49 @@
+# COPIED FROM
+# https://github.com/shawntan/stickbreaking-attention/blob/main/stickbreaking_attention/sb_varlen/softplus.py
+
+import triton
+from triton import language as tl
+
+
+def _generate_asm(num_pack):
+    template = """
+    .reg .pred p;
+    setp.gt.f32 p, ${in_reg}, 15.;
+    @p mov.f32 ${out_reg}, ${in_reg};
+    @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg};
+    @!p add.f32 ${out_reg}, ${out_reg}, 1.0;
+    @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg};
+    """
+    out_str = ""
+
+    for i in range(num_pack):
+        inner_str = template.format(out_reg=i, in_reg=i + num_pack)
+        out_str += "{" + inner_str + "}\n"
+    # flatten out because torch.compile doesn't like newlines
+    out_str = " ".join(out_str.split("\n"))
+    return out_str
+
+
+def _generate_constraints(num_pack):
+    return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack))
+
+
+_NUM_REG = 1
+asm_str: tl.constexpr = tl.constexpr(_generate_asm(_NUM_REG))
+constraints_str: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG))
+NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG)
+
+
+@triton.jit
+def softplus(x):
+    # return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)
+    return tl.inline_asm_elementwise(
+        asm=asm_str,
+        constraints=constraints_str,
+        pack=NUM_REG,
+        args=[
+            x,
+        ],
+        dtype=tl.float32,
+        is_pure=True,
+    )
diff
--git a/tests/models/test_modeling_stickbreaking_attn.py b/tests/models/test_modeling_stickbreaking_attn.py new file mode 100644 index 000000000..a56ae2672 --- /dev/null +++ b/tests/models/test_modeling_stickbreaking_attn.py @@ -0,0 +1,55 @@ +import pytest +import torch + +from fla.models import StickBreakingAttentionConfig + +from .test_modeling_base import run_test_generation, run_test_model_forward_backward + + +# =================================================================================== +# Test for Modeling (Forward/Backward Pass) +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test)) + for test in [ + (4, 4, 1024, 4, 64, True, torch.bfloat16), + (4, 4, 1024, 4, 64, False, torch.bfloat16), + (4, 4, 1024, 4, 128, False, torch.bfloat16), + ] + ], +) +def test_modeling( + L: int, + B: int, + T: int, + H: int, + D: int, + use_l2warp: bool, + dtype: torch.dtype, +): + run_test_model_forward_backward(L, B, T, H, D, StickBreakingAttentionConfig, use_l2warp=use_l2warp, dtype=dtype) + + +# =================================================================================== +# Test for Generation +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 4, 2000, 8, 64, torch.float16), + ] + ], +) +def test_generation( + L: int, + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + run_test_generation(L, B, T, H, D, StickBreakingAttentionConfig, dtype) diff --git a/tests/models/test_modeling_utils.py b/tests/models/test_modeling_utils.py index e5df066fe..f25f79ada 100644 --- a/tests/models/test_modeling_utils.py +++ b/tests/models/test_modeling_utils.py @@ -23,7 +23,7 @@ GENERATION_UNSUPPORTED = [ "ABCConfig", "LinearAttentionConfig", "LightNetConfig", "Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config", - "DeltaFormerConfig", + "DeltaFormerConfig", "StickBreakingAttentionConfig", ] diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py new file mode 100644 index 000000000..056b375cc --- /dev/null +++ b/tests/ops/test_stickbreaking_attn.py @@ -0,0 +1,138 @@ +import math + +import pytest +import torch + +from fla.ops.stickbreaking_attn import naive_stickbreaking_attn, parallel_stickbreaking_attn +from fla.utils import assert_close, device, is_intel_alchemist + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 128, 2, 64, torch.float16), + (1, 256, 4, 64, torch.float16), + (2, 512, 4, 64, torch.float16), + (4, 1024, 4, 128, torch.float16), + ] + ], +) +@pytest.mark.skipif( + is_intel_alchemist, + reason="Skipping test on Intel Alchemist due to known issues with SRAM.", +) +def test_stickbreaking_attn( + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + torch.manual_seed(42) + + q = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + + do = torch.randn((B, T, H, D), dtype=dtype, device=device) + dr = torch.randn((B, T, H), dtype=dtype, 
device=device) + + scale = 1.0 / math.sqrt(D) + + # Reference (naive) + ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, scale) + (ref_o * do).sum().backward(retain_graph=True) + (ref_rem * dr).sum().backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + # Triton fused + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale) + (tri_o * do).sum().backward(retain_graph=True) + (tri_rem * dr).sum().backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + # Compare + assert_close(" o", ref_o, tri_o, 0.008) + assert_close("rem", ref_rem, tri_rem, 0.02) + assert_close("dq", ref_dq, tri_dq, 0.02) + assert_close("dk", ref_dk, tri_dk, 0.02) + assert_close("dv", ref_dv, tri_dv, 0.02) + + +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 64, [0, 63], torch.float16), + (4, 64, [0, 256, 500, 1000], torch.float16), + (4, 128, [0, 15, 100, 300, 1200, 2000], torch.float16), + (2, 128, [0, 100, 123, 300, 500, 800, 1000, 1500, 2048], torch.float16), + ] + ], +) +@pytest.mark.skipif( + is_intel_alchemist, + reason="Skipping test on Intel Alchemist due to known issues with SRAM.", +) +def test_stickbreaking_attn_varlen( + H: int, + D: int, + cu_seqlens: list[int], + dtype: torch.dtype, +): + torch.manual_seed(42) + + T = cu_seqlens[-1] + num_chunks = len(cu_seqlens) - 1 + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + + q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + + do = torch.randn_like(q) + dr = torch.randn((1, T, H), dtype=dtype, device=device) + + scale = 1.0 / math.sqrt(D) + + ref_os = [] + ref_rems = [] + for idx in range(num_chunks): + start, end = cu_seqlens[idx], cu_seqlens[idx + 1] + ref_o_chunk, ref_rem_chunk = naive_stickbreaking_attn( + q[:, start:end], + k[:, start:end], + v[:, start:end], + scale, + ) + ref_os.append(ref_o_chunk) + ref_rems.append(ref_rem_chunk) + + ref_o = torch.cat(ref_os, dim=1) + ref_rem = torch.cat(ref_rems, dim=1) + + (ref_o * do).sum().backward(retain_graph=True) + (ref_rem * dr).sum().backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale, cu_seqlens=cu_seqlens_tensor) + (tri_o * do).sum().backward(retain_graph=True) + (tri_rem * dr).sum().backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + assert_close("o", ref_o, tri_o, 0.008) + assert_close("rem", ref_rem, tri_rem, 0.02) + assert_close("dq", ref_dq, tri_dq, 0.02) + assert_close("dk", ref_dk, tri_dk, 0.02) + assert_close("dv", ref_dv, tri_dv, 0.02)
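
For reference, a minimal usage sketch of the classes introduced by this patch. This is illustrative only and assumes a CUDA device with Triton available; the hidden size, layer count, sequence length, and vocabulary size below are placeholder values, not defaults taken from the patch.

# Hypothetical quick-start for the new model and op (shapes are illustrative).
import torch

from fla.models import StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM
from fla.ops.stickbreaking_attn import parallel_stickbreaking_attn

config = StickBreakingAttentionConfig(
    hidden_size=256,
    num_hidden_layers=2,
    num_heads=4,
    vocab_size=1000,
)
model = StickBreakingAttentionForCausalLM(config).to('cuda', torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (2, 128), device='cuda')
# `use_cache` is not supported by the layer and is disabled internally with a warning.
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss)

# The fused op can also be called directly on [B, T, H, D] tensors;
# `rem` is the leftover stick-breaking mass per query position, shaped [B, T, H].
q = torch.randn(2, 128, 4, 64, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
o, rem = parallel_stickbreaking_attn(q, k, v)

Since the layer falls back to `use_cache=False`, decoding runs without a KV cache; variable-length batches are handled by passing `cu_seqlens` with a batch size of 1, as exercised in tests/ops/test_stickbreaking_attn.py.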