From db7411ad698d8c5efc72d4412b6a393dbd948f3f Mon Sep 17 00:00:00 2001 From: Nathancgy Date: Tue, 23 Sep 2025 15:25:04 +0000 Subject: [PATCH 01/10] [Stick-Breaking Attention] Add Model --- fla/__init__.py | 6 +- fla/layers/__init__.py | 2 + fla/layers/stickbreaking_attn.py | 112 +++ fla/models/__init__.py | 6 + fla/models/stickbreaking_attn/__init__.py | 16 + .../configuration_stickbreaking_attn.py | 85 +++ .../modeling_stickbreaking_attn.py | 344 +++++++++ fla/ops/__init__.py | 3 + fla/ops/stickbreaking_attn/__init__.py | 9 + fla/ops/stickbreaking_attn/naive.py | 62 ++ fla/ops/stickbreaking_attn/parallel.py | 718 ++++++++++++++++++ .../test_modeling_stickbreaking_attn.py | 57 ++ tests/models/test_modeling_utils.py | 2 +- tests/ops/test_stickbreaking_attn.py | 65 ++ 14 files changed, 1485 insertions(+), 2 deletions(-) create mode 100644 fla/layers/stickbreaking_attn.py create mode 100644 fla/models/stickbreaking_attn/__init__.py create mode 100644 fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py create mode 100644 fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py create mode 100644 fla/ops/stickbreaking_attn/__init__.py create mode 100644 fla/ops/stickbreaking_attn/naive.py create mode 100644 fla/ops/stickbreaking_attn/parallel.py create mode 100644 tests/models/test_modeling_stickbreaking_attn.py create mode 100644 tests/ops/test_stickbreaking_attn.py diff --git a/fla/__init__.py b/fla/__init__.py index 3da00f9af..2651a5a47 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -26,7 +26,8 @@ ReBasedLinearAttention, RodimusAttention, RWKV6Attention, - RWKV7Attention + RWKV7Attention, + StickBreakingAttention ) from fla.models import ( ABCForCausalLM, @@ -75,6 +76,8 @@ RWKV6Model, RWKV7ForCausalLM, RWKV7Model, + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel, TransformerForCausalLM, TransformerModel ) @@ -106,6 +109,7 @@ 'RodimusAttention', 'RodimusForCausalLM', 'RodimusModel', 'RWKV6Attention', 'RWKV6ForCausalLM', 'RWKV6Model', 'RWKV7Attention', 'RWKV7ForCausalLM', 'RWKV7Model', + 'StickBreakingAttention', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel', ] __version__ = '0.3.2' diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 3527a387e..218b67600 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -30,6 +30,7 @@ from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention from .rwkv6 import RWKV6Attention from .rwkv7 import RWKV7Attention +from .stickbreaking_attn import StickBreakingAttention __all__ = [ 'ABCAttention', @@ -60,6 +61,7 @@ 'RodimusAttention', 'RWKV6Attention', 'RWKV7Attention', + 'StickBreakingAttention', 'SlidingWindowSharedKeyAttention', 'DeltaFormerAttention', ] diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py new file mode 100644 index 000000000..7b8fffbaf --- /dev/null +++ b/fla/layers/stickbreaking_attn.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RMSNorm +from fla.ops.stickbreaking_attn import sb_attn + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +logger = logging.get_logger(__name__) + + +class StickBreakingAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + 
num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + rope_theta: Optional[float] = None, # sba doesn't use RoPE + max_position_embeddings: Optional[int] = None, + layer_idx: int | None = None, + ): + super().__init__() + + if sb_attn is None: + raise ImportError( + "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable." + ) + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + attend_current: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + if use_cache: + warnings.warn( + "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.") + use_cache = False + + batch_size, q_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_dim) + + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + inv_temp = 1.0 / math.sqrt(self.head_dim) + + cu_seqlens = kwargs.get('cu_seqlens', None) + o, _rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=attend_current, cu_seqlens=cu_seqlens) + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + attentions = None + + return o, attentions, past_key_values diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 902ccbf81..c19ff6de4 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -31,6 +31,11 @@ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel +from fla.models.stickbreaking_attn import ( + StickBreakingAttentionConfig, + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel +) from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel __all__ = [ @@ -62,4 +67,5 @@ 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model', 'SambaConfig', 'SambaForCausalLM', 'SambaModel', 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', + 'StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel', ] diff --git a/fla/models/stickbreaking_attn/__init__.py b/fla/models/stickbreaking_attn/__init__.py new file mode 100644 index 000000000..35753f62e --- /dev/null +++ b/fla/models/stickbreaking_attn/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig +from fla.models.stickbreaking_attn.modeling_stickbreaking_attn import ( + StickBreakingAttentionForCausalLM, + StickBreakingAttentionModel +) + +AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) +AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True) +AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True) + + +__all__ = ['StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel'] diff --git a/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py new file mode 100644 index 000000000..141db9e0c --- /dev/null +++ b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +import warnings +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class StickBreakingAttentionConfig(PretrainedConfig): + + model_type = 'stickbreaking_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, 
+ fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + fuse_linear_cross_entropy: bool = False, + use_l2warp: bool = False, + vocab_size: int = 32000, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_linear_cross_entropy = fuse_linear_cross_entropy + self.use_l2warp = use_l2warp + self.vocab_size = vocab_size + + if fuse_cross_entropy and fuse_linear_cross_entropy: + raise ValueError( + "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time." + ) + if fuse_linear_cross_entropy: + warnings.warn( + "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency " + "at the potential cost of reduced precision. " + "If you observe issues like loss divergence, consider disabling this setting." + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py new file mode 100644 index 000000000..9fa4ff231 --- /dev/null +++ b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.stickbreaking_attn import StickBreakingAttention +from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig +from fla.models.utils import Cache, FLAGenerationMixin +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as SBAttnMLP +from fla.modules import RMSNorm +from fla.modules.l2warp import l2_warp + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +try: + from transformers.modeling_layers import GradientCheckpointingLayer +except ImportError: + from fla.models.modeling_layers import GradientCheckpointingLayer + +logger = logging.get_logger(__name__) + + +class StickBreakingAttentionBlock(GradientCheckpointingLayer): + + def __init__(self, config: StickBreakingAttentionConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = StickBreakingAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + qkv_bias=config.qkv_bias, + qk_norm=config.qk_norm, + 
window_size=config.window_size, + rope_theta=None, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = SBAttnMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[Any] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class StickBreakingAttentionPreTrainedModel(PreTrainedModel): + + config_class = StickBreakingAttentionConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['StickBreakingAttentionBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class StickBreakingAttentionModel(StickBreakingAttentionPreTrainedModel): + + def __init__( + self, + config: StickBreakingAttentionConfig + ) -> StickBreakingAttentionModel: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + layers = [] + for layer_idx in range(config.num_hidden_layers): + layers.append(StickBreakingAttentionBlock(config, layer_idx)) + self.layers = nn.ModuleList(layers) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return 
self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`sba` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + hidden_states = inputs_embeds + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = StickBreakingAttentionModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") 
+ def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + + logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) + + loss = None + if labels is not None: + if getattr(self, 'criterion', None) is None: + if self.config.fuse_linear_cross_entropy: + criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if self.config.fuse_linear_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + loss = l2_warp(loss, logits) if self.config.use_l2warp else loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/ops/__init__.py b/fla/ops/__init__.py index e45884b6f..a1eeb4163 100644 --- a/fla/ops/__init__.py +++ b/fla/ops/__init__.py @@ -26,6 +26,8 @@ from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 from .simple_gla import chunk_simple_gla, fused_chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla +from .stickbreaking_attn.naive import sb_attn_naive +from .stickbreaking_attn.parallel import sb_attn __all__ = [ 'chunk_abc', @@ -50,4 +52,5 @@ 'chunk_rwkv6', 'fused_recurrent_rwkv6', 'chunk_rwkv7', 'fused_recurrent_rwkv7', 'chunk_simple_gla', 'fused_chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', + 'sb_attn', 'sb_attn_naive', ] diff --git a/fla/ops/stickbreaking_attn/__init__.py b/fla/ops/stickbreaking_attn/__init__.py new file mode 100644 index 000000000..ac1e4130c --- /dev/null +++ b/fla/ops/stickbreaking_attn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .naive import sb_attn_naive +from .parallel import sb_attn + +__all__ = [ + 'sb_attn', + 'sb_attn_naive', +] diff --git 
a/fla/ops/stickbreaking_attn/naive.py b/fla/ops/stickbreaking_attn/naive.py new file mode 100644 index 000000000..f02dfd4ca --- /dev/null +++ b/fla/ops/stickbreaking_attn/naive.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Tuple + +import torch + + +def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor: + i = torch.arange(T, device=device).view(1, 1, T, 1) + j = torch.arange(T, device=device).view(1, 1, 1, T) + return (j < i) if strict else (j <= i) + + +def sb_attn_naive( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + inv_temp: float, + attend_current: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive stick-breaking attention reference implementation. + + Args: + q, k, v: [B, T, H, D] + inv_temp: inverse temperature (1/sqrt(D)) + attend_current: include diagonal when computing weights + + Returns: + o: [B, T, H, D] + rem: [B, T, H] (1 - sum of attention up to t) + """ + B, T, H, D = q.shape + orig_dtype = q.dtype + + logits = torch.einsum('bthd,bshd->bhts', q, k) * inv_temp + logits = logits.float() + + if attend_current: + mask = torch.ones(T, T, device=q.device).triu(1).bool() # exclude diagonal + else: + mask = torch.ones(T, T, device=q.device).triu(0).bool() # include diagonal + mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, T, T] + + log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype) + log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype) + + cum_weight = torch.ones(T, T, device=q.device).tril(-1) + + re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight.to(log_beta)) + log_att = log_z + re_cum_log_beta + att = log_att.exp() + o = torch.einsum('bhts,bshd->bthd', att, v) + rem = 1 - att.sum(dim=-1).transpose(1, 2) + + return o.to(orig_dtype), rem.to(orig_dtype) + + +__all__ = [ + 'sb_attn_naive', +] diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py new file mode 100644 index 000000000..dadb95673 --- /dev/null +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -0,0 +1,718 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import math +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.index import prepare_chunk_indices + +ALLOW_TF32 = True +inv_log2 = 1.0 / math.log(2.0) + + +def _get_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) for s in [4] for w in [4]] + + +@triton.autotune(configs=_get_configs(), key=["token_size", "head_size"]) +@triton.jit +def stickbreaking_attn_fwd_kernel( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + R_ptr, + A_ptr, + CU_ptr, + CI_ptr, + logit_scale: tl.constexpr, + attend_current: tl.constexpr, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + tl.num_programs(2) + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = tl.load(CI_ptr + prog_id * 2 + 
1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = 0 + seq_block_id = prog_id + seq_length = token_size + qk_scale = inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + head_id = head_pid + seq_prog_id = seq_block_id + batch_offset = batch_id * token_size + Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + K_head_seq_ptr = K_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + V_head_seq_ptr = V_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + O_head_seq_ptr = O_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + R_head_seq_ptr = R_ptr + ((batch_offset + bos) * num_heads + head_id) + A_head_seq_ptr = A_ptr + ((batch_offset + bos) * num_heads + head_id) + + stickbreaking_attn_fwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + O_head_seq_ptr, + R_head_seq_ptr, + A_head_seq_ptr, + head_size, + num_heads, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + no_grad, + acc_dtype, + False, + attend_current=attend_current, + is_compiling=is_compiling, + ) + + +@triton.jit +def stickbreaking_attn_fwd_one_row_kernel( + seq_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + O_head_seq_ptr, + R_head_seq_ptr, + A_head_seq_ptr, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + return_attention: tl.constexpr = False, + attend_current: tl.constexpr = False, + is_compiling: tl.constexpr = False, +): + block_start_offset = BLOCK_M * seq_block_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + N_blk_idxs_start = block_start_offset + BLOCK_M + N_blk_idxs = N_blk_idxs_start + N_range + + Q_blk_ptrs = Q_head_seq_ptr + ( + (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + K_blk_ptrs = K_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + V_blk_ptrs = V_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + O_blk_ptrs = O_head_seq_ptr + ( + (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + R_blk_ptrs = R_head_seq_ptr + num_heads * M_blk_idxs + A_blk_ptrs = A_head_seq_ptr + num_heads * M_blk_idxs + + if NO_D_MASK: + if NO_M_MASK: + q = tl.load(Q_blk_ptrs) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None], other=0.0) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None] & D_mask[None, :], other=0.0) + + iters = N_blk_idxs_start // BLOCK_N + neg_log_acc = tl.zeros([BLOCK_M], dtype=acc_dtype) + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=acc_dtype) + + for i in range(iters): + N_blk_idxs -= BLOCK_N + N_blk_idxs_start -= BLOCK_N + K_blk_ptrs -= BLOCK_N * (num_heads * head_size) + V_blk_ptrs -= BLOCK_N * (num_heads * head_size) + + N_mask = N_blk_idxs < seq_length + k, v = load_kv( + K_blk_ptrs, + 
V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=N_blk_idxs_start + BLOCK_N - 1 < seq_length, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + on_band = i < BLOCK_M // BLOCK_N + p, _log_om_beta, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + backward=False, + attend_current=attend_current, + is_compiling=is_compiling, + use_cumsum=False, + ) + acc = tl.dot(p.to(v.dtype), v, acc, allow_tf32=ALLOW_TF32) + + if NO_M_MASK: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc)) + tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty)) + else: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc), mask=M_mask) + tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty), mask=M_mask) + if NO_D_MASK: + tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), mask=M_mask[:, None]) + else: + tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), mask=M_mask[:, None] & D_mask[None, :]) + + +def _get_bwd_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) for s in [8] for w in [4]] + + +@triton.autotune(configs=_get_bwd_configs(), key=["token_size", "head_size"]) +@triton.jit() +def stickbreaking_attn_bwd_kernel( + DO_ptr, + DR_ptr, + A_ptr, + Q_ptr, + K_ptr, + V_ptr, + DQ_ptr, + DK_ptr, + DV_ptr, + CU_ptr, + CI_ptr, + logit_scale, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + attend_current: tl.constexpr = False, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + qk_scale = inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = 0 + seq_block_id = prog_id + seq_length = token_size + + head_id = head_pid + seq_prog_id = seq_block_id + + batch_offset = batch_id * token_size + DO_head_seq_ptr = DO_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + DR_head_seq_ptr = DR_ptr + ((batch_offset + bos) * num_heads + head_id) + A_head_seq_ptr = A_ptr + ((batch_offset + bos) * num_heads + head_id) + Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + K_head_seq_ptr = K_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + V_head_seq_ptr = V_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + DQ_head_seq_ptr = DQ_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size + DK_head_seq_ptr = DK_ptr + ( + seq_prog_id * batch_size * token_size * num_heads + (batch_offset + bos) * num_heads + head_id + ) * head_size + DV_head_seq_ptr = DV_ptr + ( + seq_prog_id * batch_size * token_size * num_heads + (batch_offset + bos) * num_heads + head_id + ) * head_size + + stickbreaking_attn_bwd_one_row_kernel( + seq_prog_id, + seq_length, + 
qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + DR_head_seq_ptr, + A_head_seq_ptr, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + DQ_head_seq_ptr, + DK_head_seq_ptr, + DV_head_seq_ptr, + logit_scale, + head_size, + num_heads, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + acc_dtype, + is_compiling=is_compiling, + attend_current=attend_current, + ) + + +@triton.jit +def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): + if NO_D_MASK: + if NO_N_MASK: + k = tl.load(K_blk_ptrs) + v = tl.load(V_blk_ptrs) + else: + k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) + v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + k = tl.load(K_blk_ptrs, mask=mask) + v = tl.load(V_blk_ptrs, mask=mask) + return k, v + + +@triton.jit +def compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band: tl.constexpr, + ALLOW_TF32: tl.constexpr, + backward: tl.constexpr, + attend_current: tl.constexpr = False, + use_cumsum: tl.constexpr = False, + is_compiling: tl.constexpr = False, +): + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale + log_om_beta = -softplus(qk, is_compiling=is_compiling) + + if on_band: + if attend_current: + block_mask = M_blk_idxs[:, None] >= N_blk_idxs[None, :] + else: + block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] + log_om_beta = tl.where(block_mask, log_om_beta, 0.0) + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + p = tl.where(block_mask, p, 0.0) + else: + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + if not backward: + neg_log_acc += tl.sum(log_om_beta, axis=1) + return p, log_om_beta, neg_log_acc + + +@triton.jit +def softplus(x, is_compiling: tl.constexpr = False): + return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) + + +@triton.jit +def stickbreaking_attn_bwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + DR_head_seq_ptr, + A_head_seq_ptr, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + DQ_head_seq_ptr, + DK_head_seq_ptr, + DV_head_seq_ptr, + logit_scale, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + attend_current: tl.constexpr = False, +): + block_start_offset = BLOCK_M * seq_prog_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + + N_blk_idxs_start = 0 + N_blk_idxs = N_blk_idxs_start + N_range + + DO_blk_ptrs = DO_head_seq_ptr + ( + (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + K_blk_ptrs = K_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + Q_blk_ptrs = Q_head_seq_ptr + ( + (num_heads * head_size) * 
M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + V_blk_ptrs = V_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + A_blk_ptrs = A_head_seq_ptr + num_heads * M_blk_idxs + DQ_blk_ptrs = DQ_head_seq_ptr + ( + (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DK_blk_ptrs = DK_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DV_blk_ptrs = DV_head_seq_ptr + ( + (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + ) + DR_blk_ptrs = DR_head_seq_ptr + num_heads * M_blk_idxs + + if NO_D_MASK: + if NO_N_MASK: + q = tl.load(Q_blk_ptrs) + do = tl.load(DO_blk_ptrs) + dr = tl.load(DR_blk_ptrs) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None]) + do = tl.load(DO_blk_ptrs, mask=M_mask[:, None]) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + MD_mask = M_mask[:, None] & D_mask[None, :] + q = tl.load(Q_blk_ptrs, mask=MD_mask) + do = tl.load(DO_blk_ptrs, mask=MD_mask) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + + neg_log_acc = neg_log_acc.to(dtype=acc_dtype) + grad_prev_acc = tl.zeros((BLOCK_M,), dtype=acc_dtype) + dq = tl.zeros((BLOCK_M, BLOCK_D), dtype=acc_dtype) + + fwd_cm = tl.trans(cm) + iters = (block_start_offset + BLOCK_M) // BLOCK_N + for i in range(iters): + on_band = (iters - i - 1) < BLOCK_M // BLOCK_N + N_mask = N_blk_idxs < seq_length + local_no_n_mask = (N_blk_idxs_start + BLOCK_N - 1) < seq_length + k, v = load_kv( + K_blk_ptrs, + V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=local_no_n_mask, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + p, log_om_beta, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + attend_current=attend_current, + backward=True, + is_compiling=is_compiling, + ) + + if not NO_M_MASK: + neg_log_acc = tl.where(M_mask, neg_log_acc, 0.0) + + att_dA = p * (tl.dot(do, tl.trans(v), allow_tf32=ALLOW_TF32) - dr[:, None]) + cumul_att_dA = tl.dot(att_dA.to(cm.dtype), fwd_cm, allow_tf32=ALLOW_TF32) + grad_prev_acc[:, None] + grad_prev_acc += tl.sum(att_dA, axis=1) + beta = 1 - tl.exp2(log_om_beta) + dqk = att_dA - beta * cumul_att_dA + + dq = tl.dot(dqk.to(k.dtype), k, acc=dq, allow_tf32=ALLOW_TF32) + block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, allow_tf32=ALLOW_TF32) * logit_scale + block_dv = tl.dot(tl.trans(p), do.to(p.dtype), allow_tf32=ALLOW_TF32) + + if NO_D_MASK: + tl.store(DK_blk_ptrs, block_dk, mask=N_mask[:, None]) + tl.store(DV_blk_ptrs, block_dv, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + tl.store(DK_blk_ptrs, block_dk, mask=mask) + tl.store(DV_blk_ptrs, block_dv, mask=mask) + + N_blk_idxs += BLOCK_N + N_blk_idxs_start += BLOCK_N + K_blk_ptrs += BLOCK_N * (num_heads * head_size) + V_blk_ptrs += BLOCK_N * (num_heads * head_size) + DK_blk_ptrs += BLOCK_N * (num_heads * head_size) + DV_blk_ptrs += BLOCK_N * (num_heads * head_size) + + dq = (logit_scale * dq).to(DQ_head_seq_ptr.type.element_ty) + + if NO_D_MASK: + tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None]) + else: + tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None] & D_mask[None, :]) + + +def stickbreaking_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + inv_temp: float, + attend_current: bool, + cu_seqlens: torch.LongTensor | None = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward Triton 
kernel and return (o, rem, neg_log_acc). + + q, k, v: [B, T, H, D] + Returns: o [B, T, H, D], rem [B, T, H], neg_log_acc [B, T, H] + """ + batch_size, token_size, num_heads, dim_size = q.size() + o = torch.empty_like(q) + rem = torch.zeros_like(q[:, :, :, 0], device=q.device) + neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) + + BLOCK_M = 64 + BLOCK_N = 64 + if cu_seqlens is None: + num_seq_blocks = triton.cdiv(token_size, BLOCK_M) + grid = (batch_size, num_heads, num_seq_blocks) + CI = None + else: + CI = prepare_chunk_indices(cu_seqlens, BLOCK_M) + num_seq_blocks = int(CI.shape[0]) + grid = (1, num_heads, num_seq_blocks) + BLOCK_D = triton.next_power_of_2(dim_size) + + stickbreaking_attn_fwd_kernel[grid]( + q, + k, + v, + o, + rem, + neg_log_acc, + CU_ptr=cu_seqlens if cu_seqlens is not None else q, + CI_ptr=CI if CI is not None else q, + logit_scale=inv_temp, + attend_current=attend_current, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + BLOCK_D=BLOCK_D, + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=(token_size % BLOCK_M) == 0, + NO_N_MASK=(token_size % BLOCK_N) == 0, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + no_grad=False, + is_compiling=False, + IS_VARLEN=cu_seqlens is not None, + ) + + return o, rem, neg_log_acc + + +def stickbreaking_attn_bwd( + do: torch.Tensor, + dr: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + neg_log_acc: torch.Tensor, + inv_temp: float, + attend_current: bool, + cu_seqlens: torch.LongTensor | None = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, token_size, num_heads, dim_size = q.size() + BLOCK_M = 64 + BLOCK_N = 64 + if cu_seqlens is None: + M_count = triton.cdiv(token_size, BLOCK_M) + grid = (batch_size, num_heads, M_count) + CI = None + else: + CI = prepare_chunk_indices(cu_seqlens, BLOCK_M) + M_count = int(CI.shape[0]) + grid = (1, num_heads, M_count) + dq = torch.zeros_like(q) + dk = torch.zeros((M_count, batch_size, token_size, num_heads, dim_size), dtype=k.dtype, device=k.device) + dv = torch.zeros((M_count, batch_size, token_size, num_heads, dim_size), dtype=v.dtype, device=v.device) + + BLOCK_D = triton.next_power_of_2(dim_size) + stickbreaking_attn_bwd_kernel[grid]( + do, + dr, + neg_log_acc, + q, + k, + v, + dq, + dk, + dv, + CU_ptr=cu_seqlens if cu_seqlens is not None else q, + CI_ptr=CI if CI is not None else q, + logit_scale=inv_temp, + attend_current=attend_current, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=(token_size % BLOCK_M) == 0, + NO_N_MASK=(token_size % BLOCK_N) == 0, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + acc_dtype=tl.float32, + is_compiling=False, + IS_VARLEN=cu_seqlens is not None, + ) + + dk_final = dk.sum(0) + dv_final = dv.sum(0) + + return dq.to(q.dtype), dk_final, dv_final + + +class StickBreakingAttentionFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + inv_temp: float, + attend_current: bool = False, + cu_seqlens: torch.LongTensor | None = None, + ): + o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, inv_temp, attend_current, cu_seqlens) + ctx.save_for_backward(q, k, v, neg_log_acc) + ctx.inv_temp = inv_temp + ctx.attend_current = attend_current + ctx.cu_seqlens = cu_seqlens + return o, rem + + @staticmethod + 
def backward(ctx, do: torch.Tensor, drem: torch.Tensor): + q, k, v, neg_log_acc = ctx.saved_tensors + dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.inv_temp, ctx.attend_current, ctx.cu_seqlens) + return dq, dk, dv, None, None, None + + +def sb_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + inv_temp: float, + attend_current: bool = False, + cu_seqlens: torch.LongTensor | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return StickBreakingAttentionFunction.apply(q, k, v, inv_temp, attend_current, cu_seqlens) + + +__all__ = [ + 'sb_attn', +] diff --git a/tests/models/test_modeling_stickbreaking_attn.py b/tests/models/test_modeling_stickbreaking_attn.py new file mode 100644 index 000000000..58973688c --- /dev/null +++ b/tests/models/test_modeling_stickbreaking_attn.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +import pytest +import torch + +from fla.models import StickBreakingAttentionConfig + +from .test_modeling_base import run_test_generation, run_test_model_forward_backward + + +# =================================================================================== +# Test for Modeling (Forward/Backward Pass) +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test)) + for test in [ + (4, 4, 1024, 4, 64, True, torch.bfloat16), + (4, 4, 1024, 4, 64, False, torch.bfloat16), + (4, 4, 1024, 4, 128, False, torch.bfloat16), + ] + ] +) +def test_modeling( + L: int, + B: int, + T: int, + H: int, + D: int, + use_l2warp: bool, + dtype: torch.dtype, +): + run_test_model_forward_backward(L, B, T, H, D, StickBreakingAttentionConfig, use_l2warp=use_l2warp, dtype=dtype) + + +# =================================================================================== +# Test for Generation +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 4, 2000, 8, 64, torch.float16), + ] + ] +) +def test_generation( + L: int, + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + run_test_generation(L, B, T, H, D, StickBreakingAttentionConfig, dtype) diff --git a/tests/models/test_modeling_utils.py b/tests/models/test_modeling_utils.py index e0b81a9e6..e10e7c34c 100644 --- a/tests/models/test_modeling_utils.py +++ b/tests/models/test_modeling_utils.py @@ -24,7 +24,7 @@ GENERATION_UNSUPPORTED = [ "ABCConfig", "LinearAttentionConfig", "LightNetConfig", "Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config", - "DeltaFormerConfig", + "DeltaFormerConfig", "StickBreakingAttentionConfig", ] diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py new file mode 100644 index 000000000..900739dfe --- /dev/null +++ b/tests/ops/test_stickbreaking_attn.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +import math + +import pytest +import torch + +from fla.ops import sb_attn, sb_attn_naive +from fla.utils import assert_close, device, is_intel_alchemist + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 64, 2, 64, torch.float32), + (1, 128, 4, 64, torch.bfloat16), + ] + ] +) +@pytest.mark.skipif( + is_intel_alchemist, + reason="Skipping test 
on Intel Alchemist due to known issues with SRAM." +) +def test_stickbreaking_attn( + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype +): + torch.manual_seed(42) + + q = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) + + do = torch.randn((B, T, H, D), dtype=dtype, device=device) + dr = torch.randn((B, T, H), dtype=dtype, device=device) + + inv_temp = 1.0 / math.sqrt(D) + + # Reference (naive) + ref_o, ref_rem = sb_attn_naive(q, k, v, inv_temp, attend_current=False) + (ref_o * do).sum().backward(retain_graph=True) + (ref_rem * dr).sum().backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + # Triton fused + tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False) + (tri_o * do).sum().backward(retain_graph=True) + (tri_rem * dr).sum().backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + # Compare + assert_close(" o", ref_o, tri_o, 0.008) + assert_close("rem", ref_rem, tri_rem, 0.02) + assert_close("dq", ref_dq, tri_dq, 0.02) + assert_close("dk", ref_dk, tri_dk, 0.02) + assert_close("dv", ref_dv, tri_dv, 0.02) From cb99d40b256af83666caa23302022dcf947b208a Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 08:22:56 +0000 Subject: [PATCH 02/10] Remove rope --- fla/__init__.py | 2 +- fla/layers/stickbreaking_attn.py | 6 +----- .../stickbreaking_attn/modeling_stickbreaking_attn.py | 7 ++++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/fla/__init__.py b/fla/__init__.py index a4b088381..82d3ee49b 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -26,7 +26,7 @@ RodimusAttention, RWKV6Attention, RWKV7Attention, - StickBreakingAttention + StickBreakingAttention, ) from fla.models import ( ABCForCausalLM, diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py index 7b8fffbaf..785a9c408 100644 --- a/fla/layers/stickbreaking_attn.py +++ b/fla/layers/stickbreaking_attn.py @@ -32,7 +32,6 @@ def __init__( qkv_bias: bool = False, qk_norm: bool = False, window_size: Optional[int] = None, - rope_theta: Optional[float] = None, # sba doesn't use RoPE max_position_embeddings: Optional[int] = None, layer_idx: int | None = None, ): @@ -56,7 +55,6 @@ def __init__( self.qk_norm = qk_norm self.window_size = window_size - self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.layer_idx = layer_idx @@ -107,6 +105,4 @@ def forward( o = o.reshape(batch_size, q_len, -1) o = self.o_proj(o) - attentions = None - - return o, attentions, past_key_values + return o, None, past_key_values diff --git a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py index 9fa4ff231..f892783f7 100644 --- a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py +++ b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py @@ -8,13 +8,15 @@ import torch import torch.nn as nn -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.utils.deprecation import 
deprecate_kwarg from fla.layers.stickbreaking_attn import StickBreakingAttention -from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig +from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import \ + StickBreakingAttentionConfig from fla.models.utils import Cache, FLAGenerationMixin from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss from fla.modules import GatedMLP as SBAttnMLP @@ -49,7 +51,6 @@ def __init__(self, config: StickBreakingAttentionConfig, layer_idx: int): qkv_bias=config.qkv_bias, qk_norm=config.qk_norm, window_size=config.window_size, - rope_theta=None, max_position_embeddings=config.max_position_embeddings, layer_idx=layer_idx ) From 9b2554c2b1bd8b68c3f064e909cd9b5a92e946d3 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 08:27:56 +0000 Subject: [PATCH 03/10] Fix lint --- fla/layers/stickbreaking_attn.py | 19 ++--- fla/models/__init__.py | 2 +- fla/models/stickbreaking_attn/__init__.py | 3 +- .../configuration_stickbreaking_attn.py | 19 ++--- .../modeling_stickbreaking_attn.py | 83 +++++++++---------- fla/ops/stickbreaking_attn/__init__.py | 2 - fla/ops/stickbreaking_attn/naive.py | 6 +- fla/ops/stickbreaking_attn/parallel.py | 12 ++- .../test_modeling_stickbreaking_attn.py | 6 +- tests/ops/test_stickbreaking_attn.py | 8 +- 10 files changed, 70 insertions(+), 90 deletions(-) diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py index 785a9c408..19c0db622 100644 --- a/fla/layers/stickbreaking_attn.py +++ b/fla/layers/stickbreaking_attn.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang from __future__ import annotations import math import warnings -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -28,18 +27,18 @@ def __init__( self, hidden_size: int = 2048, num_heads: int = 32, - num_kv_heads: Optional[int] = None, + num_kv_heads: int | None = None, qkv_bias: bool = False, qk_norm: bool = False, - window_size: Optional[int] = None, - max_position_embeddings: Optional[int] = None, + window_size: int | None = None, + max_position_embeddings: int | None = None, layer_idx: int | None = None, ): super().__init__() if sb_attn is None: raise ImportError( - "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable." + "StickBreakingAttention kernels are not available. 
Ensure Triton is installed and ops are importable.", ) self.hidden_size = hidden_size @@ -70,13 +69,13 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, attend_current: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: if attention_mask is not None: assert len(attention_mask.shape) == 2, ( "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " @@ -100,7 +99,7 @@ def forward( inv_temp = 1.0 / math.sqrt(self.head_dim) - cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens') o, _rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=attend_current, cu_seqlens=cu_seqlens) o = o.reshape(batch_size, q_len, -1) o = self.o_proj(o) diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 34b24f4f3..a95d7d1d2 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -34,7 +34,7 @@ from fla.models.stickbreaking_attn import ( StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, - StickBreakingAttentionModel + StickBreakingAttentionModel, ) from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel diff --git a/fla/models/stickbreaking_attn/__init__.py b/fla/models/stickbreaking_attn/__init__.py index 35753f62e..f9fdf7297 100644 --- a/fla/models/stickbreaking_attn/__init__.py +++ b/fla/models/stickbreaking_attn/__init__.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig from fla.models.stickbreaking_attn.modeling_stickbreaking_attn import ( StickBreakingAttentionForCausalLM, - StickBreakingAttentionModel + StickBreakingAttentionModel, ) AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) diff --git a/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py index 141db9e0c..439cc38f2 100644 --- a/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py +++ b/fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py @@ -1,7 +1,4 @@ -# -*- coding: utf-8 -*- - import warnings -from typing import Optional from transformers.configuration_utils import PretrainedConfig @@ -16,19 +13,19 @@ def __init__( hidden_size: int = 2048, num_hidden_layers: int = 24, num_heads: int = 32, - num_kv_heads: Optional[int] = None, + num_kv_heads: int | None = None, qkv_bias: bool = False, qk_norm: bool = False, - window_size: Optional[int] = None, + window_size: int | None = None, max_position_embeddings: int = 2048, - hidden_ratio: Optional[int] = 4, - intermediate_size: Optional[int] = None, + hidden_ratio: int | None = 4, + intermediate_size: int | None = None, hidden_act: str = "swish", initializer_range: float = 0.02, - elementwise_affine: Optional[bool] = True, + elementwise_affine: bool | None = True, norm_eps: float = 1e-6, use_cache: bool = True, - pad_token_id: Optional[int] = None, + pad_token_id: int | None = None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings: bool = False, @@ 
-67,13 +64,13 @@ def __init__( if fuse_cross_entropy and fuse_linear_cross_entropy: raise ValueError( - "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time." + "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.", ) if fuse_linear_cross_entropy: warnings.warn( "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency " "at the potential cost of reduced precision. " - "If you observe issues like loss divergence, consider disabling this setting." + "If you observe issues like loss divergence, consider disabling this setting.", ) super().__init__( diff --git a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py index f892783f7..a72e1ae0e 100644 --- a/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py +++ b/fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py @@ -1,26 +1,21 @@ -# -*- coding: utf-8 -*- - from __future__ import annotations import math import warnings -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.utils.deprecation import deprecate_kwarg from fla.layers.stickbreaking_attn import StickBreakingAttention -from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import \ - StickBreakingAttentionConfig +from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig from fla.models.utils import Cache, FLAGenerationMixin -from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm from fla.modules import GatedMLP as SBAttnMLP -from fla.modules import RMSNorm from fla.modules.l2warp import l2_warp if TYPE_CHECKING: @@ -52,7 +47,7 @@ def __init__(self, config: StickBreakingAttentionConfig, layer_idx: int): qk_norm=config.qk_norm, window_size=config.window_size, max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx + layer_idx=layer_idx, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) @@ -61,18 +56,18 @@ def __init__(self, config: StickBreakingAttentionConfig, layer_idx: int): hidden_ratio=config.hidden_ratio, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - fuse_swiglu=config.fuse_swiglu + fuse_swiglu=config.fuse_swiglu, ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs: Unpack[Any] - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + attention_mask: torch.Tensor | None = None, + past_key_values: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + **kwargs: Unpack[Any], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: residual = hidden_states hidden_states = self.attn_norm(hidden_states) @@ -82,7 +77,7 @@ def forward( 
past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, - **kwargs + **kwargs, ) if self.config.fuse_norm: hidden_states, residual = self.mlp_norm(hidden_states, residual, True) @@ -146,7 +141,7 @@ class StickBreakingAttentionModel(StickBreakingAttentionPreTrainedModel): def __init__( self, - config: StickBreakingAttentionConfig + config: StickBreakingAttentionConfig, ) -> StickBreakingAttentionModel: super().__init__(config) self.padding_idx = config.pad_token_id @@ -171,19 +166,19 @@ def set_input_embeddings(self, value): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Any] - ) -> Union[Tuple, CausalLMOutputWithPast]: + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs: Unpack[Any], + ) -> tuple | CausalLMOutputWithPast: if output_attentions: warnings.warn( - "`sba` does not support output attention weights now, so `output_attentions` is set to `False`." + "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.", ) output_attentions = False output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -218,7 +213,7 @@ def forward( past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - **kwargs + **kwargs, ) hidden_states = layer_outputs[0] @@ -241,7 +236,7 @@ def forward( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, - attentions=all_attns + attentions=all_attns, ) @@ -280,17 +275,17 @@ def get_decoder(self): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Optional[int] = 0, - **kwargs: Unpack[Any] - ) -> Union[Tuple, CausalLMOutputWithPast]: + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | None = 0, + **kwargs: Unpack[Any], + ) -> tuple | CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -306,7 +301,7 @@ def forward( output_attentions=output_attentions, 
output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs + **kwargs, ) hidden_states = outputs[0] diff --git a/fla/ops/stickbreaking_attn/__init__.py b/fla/ops/stickbreaking_attn/__init__.py index ac1e4130c..e2dcc86eb 100644 --- a/fla/ops/stickbreaking_attn/__init__.py +++ b/fla/ops/stickbreaking_attn/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from .naive import sb_attn_naive from .parallel import sb_attn diff --git a/fla/ops/stickbreaking_attn/naive.py b/fla/ops/stickbreaking_attn/naive.py index f02dfd4ca..4fe936cc3 100644 --- a/fla/ops/stickbreaking_attn/naive.py +++ b/fla/ops/stickbreaking_attn/naive.py @@ -1,7 +1,5 @@ -# -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -from typing import Tuple import torch @@ -17,8 +15,8 @@ def sb_attn_naive( k: torch.Tensor, v: torch.Tensor, inv_temp: float, - attend_current: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: + attend_current: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: """ Naive stick-breaking attention reference implementation. diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py index dadb95673..391c4e3c1 100644 --- a/fla/ops/stickbreaking_attn/parallel.py +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang import math -from typing import Tuple import torch import triton @@ -558,7 +556,7 @@ def stickbreaking_attn_fwd( inv_temp: float, attend_current: bool, cu_seqlens: torch.LongTensor | None = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Run forward Triton kernel and return (o, rem, neg_log_acc). @@ -598,7 +596,7 @@ def stickbreaking_attn_fwd( head_size=dim_size, num_heads=num_heads, BLOCK_D=BLOCK_D, - NO_D_MASK=BLOCK_D == dim_size, + NO_D_MASK=dim_size == BLOCK_D, NO_M_MASK=(token_size % BLOCK_M) == 0, NO_N_MASK=(token_size % BLOCK_N) == 0, ALLOW_TF32=ALLOW_TF32, @@ -623,7 +621,7 @@ def stickbreaking_attn_bwd( inv_temp: float, attend_current: bool, cu_seqlens: torch.LongTensor | None = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, token_size, num_heads, dim_size = q.size() BLOCK_M = 64 BLOCK_N = 64 @@ -661,7 +659,7 @@ def stickbreaking_attn_bwd( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D, - NO_D_MASK=BLOCK_D == dim_size, + NO_D_MASK=dim_size == BLOCK_D, NO_M_MASK=(token_size % BLOCK_M) == 0, NO_N_MASK=(token_size % BLOCK_N) == 0, ALLOW_TF32=ALLOW_TF32, @@ -709,7 +707,7 @@ def sb_attn( inv_temp: float, attend_current: bool = False, cu_seqlens: torch.LongTensor | None = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: return StickBreakingAttentionFunction.apply(q, k, v, inv_temp, attend_current, cu_seqlens) diff --git a/tests/models/test_modeling_stickbreaking_attn.py b/tests/models/test_modeling_stickbreaking_attn.py index 58973688c..a56ae2672 100644 --- a/tests/models/test_modeling_stickbreaking_attn.py +++ b/tests/models/test_modeling_stickbreaking_attn.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import pytest import torch @@ -20,7 +18,7 @@ (4, 4, 1024, 4, 64, False, torch.bfloat16), (4, 4, 1024, 4, 128, False, torch.bfloat16), ] - ] + ], ) def test_modeling( L: int, @@ -44,7 +42,7 @@ def test_modeling( for test in [ (2, 4, 2000, 8, 64, torch.float16), ] - ] + ], ) def test_generation( L: int, diff --git 
a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py index 900739dfe..49296c408 100644 --- a/tests/ops/test_stickbreaking_attn.py +++ b/tests/ops/test_stickbreaking_attn.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import math import pytest @@ -17,18 +15,18 @@ (2, 64, 2, 64, torch.float32), (1, 128, 4, 64, torch.bfloat16), ] - ] + ], ) @pytest.mark.skipif( is_intel_alchemist, - reason="Skipping test on Intel Alchemist due to known issues with SRAM." + reason="Skipping test on Intel Alchemist due to known issues with SRAM.", ) def test_stickbreaking_attn( B: int, T: int, H: int, D: int, - dtype: torch.dtype + dtype: torch.dtype, ): torch.manual_seed(42) From b666a57c67060b5b9a4e328fad3254fda853ae2b Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 08:38:01 +0000 Subject: [PATCH 04/10] Fix names --- fla/layers/stickbreaking_attn.py | 15 +++++---- fla/ops/__init__.py | 13 ++++---- fla/ops/stickbreaking_attn/__init__.py | 8 ++--- fla/ops/stickbreaking_attn/naive.py | 26 +++++++--------- fla/ops/stickbreaking_attn/parallel.py | 42 ++++++++++++++------------ tests/ops/test_stickbreaking_attn.py | 6 ++-- 6 files changed, 55 insertions(+), 55 deletions(-) diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py index 19c0db622..8e929db3f 100644 --- a/fla/layers/stickbreaking_attn.py +++ b/fla/layers/stickbreaking_attn.py @@ -2,7 +2,6 @@ from __future__ import annotations -import math import warnings from typing import TYPE_CHECKING @@ -12,7 +11,7 @@ from transformers.utils import logging from fla.modules import RMSNorm -from fla.ops.stickbreaking_attn import sb_attn +from fla.ops.stickbreaking_attn import parallel_stickbreaking_attn if TYPE_CHECKING: from fla.models.utils import Cache @@ -36,7 +35,7 @@ def __init__( ): super().__init__() - if sb_attn is None: + if parallel_stickbreaking_attn is None: raise ImportError( "StickBreakingAttention kernels are not available. 
Ensure Triton is installed and ops are importable.", ) @@ -97,10 +96,14 @@ def forward( if self.qk_norm: q, k = self.q_norm(q), self.k_norm(k) - inv_temp = 1.0 / math.sqrt(self.head_dim) - cu_seqlens = kwargs.get('cu_seqlens') - o, _rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=attend_current, cu_seqlens=cu_seqlens) + o, _rem = parallel_stickbreaking_attn( + q=q, + k=k, + v=v, + attend_current=attend_current, + cu_seqlens=cu_seqlens, + ) o = o.reshape(batch_size, q_len, -1) o = self.o_proj(o) diff --git a/fla/ops/__init__.py b/fla/ops/__init__.py index b62bbc02d..a773016f4 100644 --- a/fla/ops/__init__.py +++ b/fla/ops/__init__.py @@ -7,10 +7,10 @@ from .forgetting_attn import parallel_forgetting_attn from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule from .generalized_delta_rule import ( - chunk_dplr_delta_rule, - chunk_iplr_delta_rule, - fused_recurrent_dplr_delta_rule, - fused_recurrent_iplr_delta_rule, + chunk_dplr_delta_rule, + chunk_iplr_delta_rule, + fused_recurrent_dplr_delta_rule, + fused_recurrent_iplr_delta_rule, ) from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla from .gsa import chunk_gsa, fused_recurrent_gsa @@ -26,8 +26,7 @@ from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 from .simple_gla import chunk_simple_gla, fused_chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla -from .stickbreaking_attn.naive import sb_attn_naive -from .stickbreaking_attn.parallel import sb_attn +from .stickbreaking_attn.parallel import parallel_stickbreaking_attn __all__ = [ 'chunk_abc', @@ -53,5 +52,5 @@ 'chunk_rwkv6', 'fused_recurrent_rwkv6', 'chunk_rwkv7', 'fused_recurrent_rwkv7', 'chunk_simple_gla', 'fused_chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', - 'sb_attn', 'sb_attn_naive', + 'parallel_stickbreaking_attn', ] diff --git a/fla/ops/stickbreaking_attn/__init__.py b/fla/ops/stickbreaking_attn/__init__.py index e2dcc86eb..16472bf23 100644 --- a/fla/ops/stickbreaking_attn/__init__.py +++ b/fla/ops/stickbreaking_attn/__init__.py @@ -1,7 +1,7 @@ -from .naive import sb_attn_naive -from .parallel import sb_attn +from .naive import naive_stickbreaking_attn +from .parallel import parallel_stickbreaking_attn __all__ = [ - 'sb_attn', - 'sb_attn_naive', + 'parallel_stickbreaking_attn', + 'naive_stickbreaking_attn', ] diff --git a/fla/ops/stickbreaking_attn/naive.py b/fla/ops/stickbreaking_attn/naive.py index 4fe936cc3..52db444b8 100644 --- a/fla/ops/stickbreaking_attn/naive.py +++ b/fla/ops/stickbreaking_attn/naive.py @@ -1,20 +1,14 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - import torch +import torch.nn.functional as F -def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor: - i = torch.arange(T, device=device).view(1, 1, T, 1) - j = torch.arange(T, device=device).view(1, 1, 1, T) - return (j < i) if strict else (j <= i) - - -def sb_attn_naive( +def naive_stickbreaking_attn( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - inv_temp: float, + scale: float | None = None, attend_current: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -22,17 +16,19 @@ def sb_attn_naive( Args: q, k, v: [B, T, H, D] - inv_temp: inverse temperature (1/sqrt(D)) + scale: inverse temperature (1/sqrt(D)) attend_current: include diagonal when computing weights Returns: o: [B, T, H, D] rem: [B, T, H] (1 - sum of attention up to t) """ - B, T, H, D = q.shape + _, T, _, D = q.shape orig_dtype = q.dtype + if scale is None: + 
scale = D ** -0.5 - logits = torch.einsum('bthd,bshd->bhts', q, k) * inv_temp + logits = torch.einsum('bthd,bshd->bhts', q, k) * scale logits = logits.float() if attend_current: @@ -41,8 +37,8 @@ def sb_attn_naive( mask = torch.ones(T, T, device=q.device).triu(0).bool() # include diagonal mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, T, T] - log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype) - log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype) + log_z = F.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype) + log_beta = F.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype) cum_weight = torch.ones(T, T, device=q.device).tril(-1) @@ -56,5 +52,5 @@ def sb_attn_naive( __all__ = [ - 'sb_attn_naive', + 'naive_stickbreaking_attn', ] diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py index 391c4e3c1..36238a565 100644 --- a/fla/ops/stickbreaking_attn/parallel.py +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -27,7 +27,7 @@ def stickbreaking_attn_fwd_kernel( A_ptr, CU_ptr, CI_ptr, - logit_scale: tl.constexpr, + scale: tl.constexpr, attend_current: tl.constexpr, batch_size, token_size, @@ -61,7 +61,7 @@ def stickbreaking_attn_fwd_kernel( bos = 0 seq_block_id = prog_id seq_length = token_size - qk_scale = inv_log2 * logit_scale + qk_scale = inv_log2 * scale M_range = tl.arange(0, BLOCK_M) N_range = tl.arange(0, BLOCK_N) D_range = tl.arange(0, BLOCK_D) @@ -237,7 +237,7 @@ def stickbreaking_attn_bwd_kernel( DV_ptr, CU_ptr, CI_ptr, - logit_scale, + scale, batch_size, token_size, head_size: tl.constexpr, @@ -259,7 +259,7 @@ def stickbreaking_attn_bwd_kernel( batch_id = 0 if IS_VARLEN else tl.program_id(0) head_pid = tl.program_id(1) prog_id = tl.program_id(2) - qk_scale = inv_log2 * logit_scale + qk_scale = inv_log2 * scale M_range = tl.arange(0, BLOCK_M) N_range = tl.arange(0, BLOCK_N) D_range = tl.arange(0, BLOCK_D) @@ -313,7 +313,7 @@ def stickbreaking_attn_bwd_kernel( DQ_head_seq_ptr, DK_head_seq_ptr, DV_head_seq_ptr, - logit_scale, + scale, head_size, num_heads, BLOCK_D, @@ -417,7 +417,7 @@ def stickbreaking_attn_bwd_one_row_kernel( DQ_head_seq_ptr, DK_head_seq_ptr, DV_head_seq_ptr, - logit_scale, + scale, head_size: tl.constexpr, num_heads: tl.constexpr, BLOCK_D: tl.constexpr, @@ -523,7 +523,7 @@ def stickbreaking_attn_bwd_one_row_kernel( dqk = att_dA - beta * cumul_att_dA dq = tl.dot(dqk.to(k.dtype), k, acc=dq, allow_tf32=ALLOW_TF32) - block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, allow_tf32=ALLOW_TF32) * logit_scale + block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, allow_tf32=ALLOW_TF32) * scale block_dv = tl.dot(tl.trans(p), do.to(p.dtype), allow_tf32=ALLOW_TF32) if NO_D_MASK: @@ -541,7 +541,7 @@ def stickbreaking_attn_bwd_one_row_kernel( DK_blk_ptrs += BLOCK_N * (num_heads * head_size) DV_blk_ptrs += BLOCK_N * (num_heads * head_size) - dq = (logit_scale * dq).to(DQ_head_seq_ptr.type.element_ty) + dq = (scale * dq).to(DQ_head_seq_ptr.type.element_ty) if NO_D_MASK: tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None]) @@ -553,7 +553,7 @@ def stickbreaking_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - inv_temp: float, + scale: float, attend_current: bool, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -589,7 +589,7 @@ def stickbreaking_attn_fwd( neg_log_acc, CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, - logit_scale=inv_temp, + scale=scale, 
attend_current=attend_current, batch_size=batch_size, token_size=token_size, @@ -618,7 +618,7 @@ def stickbreaking_attn_bwd( k: torch.Tensor, v: torch.Tensor, neg_log_acc: torch.Tensor, - inv_temp: float, + scale: float, attend_current: bool, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -650,7 +650,7 @@ def stickbreaking_attn_bwd( dv, CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, - logit_scale=inv_temp, + scale=scale, attend_current=attend_current, batch_size=batch_size, token_size=token_size, @@ -682,13 +682,13 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - inv_temp: float, + scale: float, attend_current: bool = False, cu_seqlens: torch.LongTensor | None = None, ): - o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, inv_temp, attend_current, cu_seqlens) + o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, scale, attend_current, cu_seqlens) ctx.save_for_backward(q, k, v, neg_log_acc) - ctx.inv_temp = inv_temp + ctx.scale = scale ctx.attend_current = attend_current ctx.cu_seqlens = cu_seqlens return o, rem @@ -696,21 +696,23 @@ def forward( @staticmethod def backward(ctx, do: torch.Tensor, drem: torch.Tensor): q, k, v, neg_log_acc = ctx.saved_tensors - dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.inv_temp, ctx.attend_current, ctx.cu_seqlens) + dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.attend_current, ctx.cu_seqlens) return dq, dk, dv, None, None, None -def sb_attn( +def parallel_stickbreaking_attn( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - inv_temp: float, + scale: float | None = None, attend_current: bool = False, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return StickBreakingAttentionFunction.apply(q, k, v, inv_temp, attend_current, cu_seqlens) + if scale is None: + scale = q.shape[-1] ** -0.5 + return StickBreakingAttentionFunction.apply(q, k, v, scale, attend_current, cu_seqlens) __all__ = [ - 'sb_attn', + 'parallel_stickbreaking_attn', ] diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py index 49296c408..35d41fab3 100644 --- a/tests/ops/test_stickbreaking_attn.py +++ b/tests/ops/test_stickbreaking_attn.py @@ -3,7 +3,7 @@ import pytest import torch -from fla.ops import sb_attn, sb_attn_naive +from fla.ops import naive_stickbreaking_attn, parallel_stickbreaking_attn from fla.utils import assert_close, device, is_intel_alchemist @@ -40,7 +40,7 @@ def test_stickbreaking_attn( inv_temp = 1.0 / math.sqrt(D) # Reference (naive) - ref_o, ref_rem = sb_attn_naive(q, k, v, inv_temp, attend_current=False) + ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, inv_temp, attend_current=False) (ref_o * do).sum().backward(retain_graph=True) (ref_rem * dr).sum().backward() ref_dq, q.grad = q.grad.clone(), None @@ -48,7 +48,7 @@ def test_stickbreaking_attn( ref_dv, v.grad = v.grad.clone(), None # Triton fused - tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False) + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, inv_temp=inv_temp, attend_current=False) (tri_o * do).sum().backward(retain_graph=True) (tri_rem * dr).sum().backward() tri_dq, q.grad = q.grad.clone(), None From 1f6b4642a37e8aee9791e53d49f275fb3549581a Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 08:40:03 +0000 Subject: [PATCH 05/10] Fix names --- tests/ops/test_stickbreaking_attn.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py index 35d41fab3..0bf1c13c9 100644 --- a/tests/ops/test_stickbreaking_attn.py +++ b/tests/ops/test_stickbreaking_attn.py @@ -37,10 +37,10 @@ def test_stickbreaking_attn( do = torch.randn((B, T, H, D), dtype=dtype, device=device) dr = torch.randn((B, T, H), dtype=dtype, device=device) - inv_temp = 1.0 / math.sqrt(D) + scale = 1.0 / math.sqrt(D) # Reference (naive) - ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, inv_temp, attend_current=False) + ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, scale, attend_current=False) (ref_o * do).sum().backward(retain_graph=True) (ref_rem * dr).sum().backward() ref_dq, q.grad = q.grad.clone(), None @@ -48,7 +48,7 @@ def test_stickbreaking_attn( ref_dv, v.grad = v.grad.clone(), None # Triton fused - tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, inv_temp=inv_temp, attend_current=False) + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale, attend_current=False) (tri_o * do).sum().backward(retain_graph=True) (tri_rem * dr).sum().backward() tri_dq, q.grad = q.grad.clone(), None From b15eaf66176b1bd5f355f7064483d932472ffa56 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 10:50:31 +0000 Subject: [PATCH 06/10] Fix imports --- tests/ops/test_stickbreaking_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py index 0bf1c13c9..f3baf876a 100644 --- a/tests/ops/test_stickbreaking_attn.py +++ b/tests/ops/test_stickbreaking_attn.py @@ -3,7 +3,7 @@ import pytest import torch -from fla.ops import naive_stickbreaking_attn, parallel_stickbreaking_attn +from fla.ops.stickbreaking_attn import naive_stickbreaking_attn, parallel_stickbreaking_attn from fla.utils import assert_close, device, is_intel_alchemist From 93c98b36264c2e36e28997a19d618ebf0097854b Mon Sep 17 00:00:00 2001 From: Nathancgy4 Date: Mon, 10 Nov 2025 12:27:55 +0000 Subject: [PATCH 07/10] Fixed sba varlen masking and remove unnecessary param --- fla/layers/stickbreaking_attn.py | 2 - fla/ops/stickbreaking_attn/naive.py | 8 +- fla/ops/stickbreaking_attn/parallel.py | 84 ++- sba_code/.gitignore | 162 +++++ sba_code/LICENSE | 201 ++++++ sba_code/README.md | 32 + sba_code/benchmarks/attn.py | 96 +++ sba_code/benchmarks/varlen.py | 129 ++++ sba_code/load_model_with_dolomite_demo.py | 7 + sba_code/setup.py | 29 + sba_code/stickbreaking_attention/__init__.py | 2 + .../sb_attn/__init__.py | 64 ++ .../stickbreaking_attention/sb_attn/sb_bwd.py | 297 ++++++++ .../stickbreaking_attention/sb_attn/sb_fwd.py | 253 +++++++ sba_code/stickbreaking_attention/sb_ref.py | 25 + .../sb_varlen/__init__.py | 82 +++ .../sb_varlen/sb_varlen_bwd.py | 641 ++++++++++++++++++ .../sb_varlen/sb_varlen_fwd.py | 522 ++++++++++++++ .../sb_varlen/softplus.py | 52 ++ sba_code/stickbreaking_attention/utils.py | 39 ++ sba_code/tests/__init__.py | 0 sba_code/tests/test_attn.py | 74 ++ sba_code/tests/test_varlen.py | 110 +++ tests/ops/test_stickbreaking_attn.py | 83 ++- 24 files changed, 2938 insertions(+), 56 deletions(-) create mode 100644 sba_code/.gitignore create mode 100644 sba_code/LICENSE create mode 100644 sba_code/README.md create mode 100644 sba_code/benchmarks/attn.py create mode 100644 sba_code/benchmarks/varlen.py create mode 100644 sba_code/load_model_with_dolomite_demo.py create mode 100644 sba_code/setup.py create mode 100644 sba_code/stickbreaking_attention/__init__.py create 
mode 100644 sba_code/stickbreaking_attention/sb_attn/__init__.py create mode 100644 sba_code/stickbreaking_attention/sb_attn/sb_bwd.py create mode 100644 sba_code/stickbreaking_attention/sb_attn/sb_fwd.py create mode 100644 sba_code/stickbreaking_attention/sb_ref.py create mode 100644 sba_code/stickbreaking_attention/sb_varlen/__init__.py create mode 100644 sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py create mode 100644 sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py create mode 100644 sba_code/stickbreaking_attention/sb_varlen/softplus.py create mode 100644 sba_code/stickbreaking_attention/utils.py create mode 100644 sba_code/tests/__init__.py create mode 100644 sba_code/tests/test_attn.py create mode 100644 sba_code/tests/test_varlen.py diff --git a/fla/layers/stickbreaking_attn.py b/fla/layers/stickbreaking_attn.py index 8e929db3f..fd81bbee0 100644 --- a/fla/layers/stickbreaking_attn.py +++ b/fla/layers/stickbreaking_attn.py @@ -72,7 +72,6 @@ def forward( past_key_values: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - attend_current: bool = False, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: if attention_mask is not None: @@ -101,7 +100,6 @@ def forward( q=q, k=k, v=v, - attend_current=attend_current, cu_seqlens=cu_seqlens, ) o = o.reshape(batch_size, q_len, -1) diff --git a/fla/ops/stickbreaking_attn/naive.py b/fla/ops/stickbreaking_attn/naive.py index 52db444b8..4872ecc67 100644 --- a/fla/ops/stickbreaking_attn/naive.py +++ b/fla/ops/stickbreaking_attn/naive.py @@ -9,7 +9,6 @@ def naive_stickbreaking_attn( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, - attend_current: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Naive stick-breaking attention reference implementation. 
@@ -17,8 +16,6 @@ def naive_stickbreaking_attn( Args: q, k, v: [B, T, H, D] scale: inverse temperature (1/sqrt(D)) - attend_current: include diagonal when computing weights - Returns: o: [B, T, H, D] rem: [B, T, H] (1 - sum of attention up to t) @@ -31,10 +28,7 @@ def naive_stickbreaking_attn( logits = torch.einsum('bthd,bshd->bhts', q, k) * scale logits = logits.float() - if attend_current: - mask = torch.ones(T, T, device=q.device).triu(1).bool() # exclude diagonal - else: - mask = torch.ones(T, T, device=q.device).triu(0).bool() # include diagonal + mask = torch.ones(T, T, device=q.device).triu(0).bool() # exclude diagonal mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, T, T] log_z = F.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype) diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py index 36238a565..0642aaa47 100644 --- a/fla/ops/stickbreaking_attn/parallel.py +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -28,7 +28,6 @@ def stickbreaking_attn_fwd_kernel( CU_ptr, CI_ptr, scale: tl.constexpr, - attend_current: tl.constexpr, batch_size, token_size, head_size: tl.constexpr, @@ -58,7 +57,7 @@ def stickbreaking_attn_fwd_kernel( eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) seq_length = eos - bos else: - bos = 0 + bos = tl.full([], 0, dtype=tl.int32) seq_block_id = prog_id seq_length = token_size qk_scale = inv_log2 * scale @@ -105,7 +104,6 @@ def stickbreaking_attn_fwd_kernel( no_grad, acc_dtype, False, - attend_current=attend_current, is_compiling=is_compiling, ) @@ -138,7 +136,6 @@ def stickbreaking_attn_fwd_one_row_kernel( no_grad: tl.constexpr = False, acc_dtype: tl.constexpr = tl.float32, return_attention: tl.constexpr = False, - attend_current: tl.constexpr = False, is_compiling: tl.constexpr = False, ): block_start_offset = BLOCK_M * seq_block_id @@ -201,7 +198,6 @@ def stickbreaking_attn_fwd_one_row_kernel( on_band, ALLOW_TF32, backward=False, - attend_current=attend_current, is_compiling=is_compiling, use_cumsum=False, ) @@ -252,7 +248,6 @@ def stickbreaking_attn_bwd_kernel( BLOCK_N: tl.constexpr, acc_dtype: tl.constexpr = tl.float32, is_compiling: tl.constexpr = False, - attend_current: tl.constexpr = False, IS_VARLEN: tl.constexpr = False, ): tl.static_assert(BLOCK_M % BLOCK_N == 0) @@ -280,20 +275,24 @@ def stickbreaking_attn_bwd_kernel( head_id = head_pid seq_prog_id = seq_block_id - batch_offset = batch_id * token_size - DO_head_seq_ptr = DO_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - DR_head_seq_ptr = DR_ptr + ((batch_offset + bos) * num_heads + head_id) - A_head_seq_ptr = A_ptr + ((batch_offset + bos) * num_heads + head_id) - Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - K_head_seq_ptr = K_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - V_head_seq_ptr = V_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - DQ_head_seq_ptr = DQ_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - DK_head_seq_ptr = DK_ptr + ( - seq_prog_id * batch_size * token_size * num_heads + (batch_offset + bos) * num_heads + head_id - ) * head_size - DV_head_seq_ptr = DV_ptr + ( - seq_prog_id * batch_size * token_size * num_heads + (batch_offset + bos) * num_heads + head_id - ) * head_size + batch_id_i64 = batch_id.to(tl.int64) + head_id_i64 = head_id.to(tl.int64) + seq_prog_id_i64 = seq_prog_id.to(tl.int64) + bos_i64 = bos.to(tl.int64) + + batch_offset = batch_id_i64 * token_size + head_offset = (batch_offset + bos_i64) * num_heads + head_id_i64 + 
block_offset = seq_prog_id_i64 * batch_size * token_size * num_heads + + DO_head_seq_ptr = DO_ptr + head_offset * head_size + DR_head_seq_ptr = DR_ptr + head_offset + A_head_seq_ptr = A_ptr + head_offset + Q_head_seq_ptr = Q_ptr + head_offset * head_size + K_head_seq_ptr = K_ptr + head_offset * head_size + V_head_seq_ptr = V_ptr + head_offset * head_size + DQ_head_seq_ptr = DQ_ptr + head_offset * head_size + DK_head_seq_ptr = DK_ptr + (block_offset + head_offset) * head_size + DV_head_seq_ptr = DV_ptr + (block_offset + head_offset) * head_size stickbreaking_attn_bwd_one_row_kernel( seq_prog_id, @@ -325,7 +324,6 @@ def stickbreaking_attn_bwd_kernel( BLOCK_N, acc_dtype, is_compiling=is_compiling, - attend_current=attend_current, ) @@ -357,7 +355,6 @@ def compute_block( on_band: tl.constexpr, ALLOW_TF32: tl.constexpr, backward: tl.constexpr, - attend_current: tl.constexpr = False, use_cumsum: tl.constexpr = False, is_compiling: tl.constexpr = False, ): @@ -365,10 +362,7 @@ def compute_block( log_om_beta = -softplus(qk, is_compiling=is_compiling) if on_band: - if attend_current: - block_mask = M_blk_idxs[:, None] >= N_blk_idxs[None, :] - else: - block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] + block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] log_om_beta = tl.where(block_mask, log_om_beta, 0.0) if backward: neg_log_acc -= tl.sum(log_om_beta, axis=1) @@ -429,7 +423,6 @@ def stickbreaking_attn_bwd_one_row_kernel( BLOCK_N: tl.constexpr, acc_dtype: tl.constexpr = tl.float32, is_compiling: tl.constexpr = False, - attend_current: tl.constexpr = False, ): block_start_offset = BLOCK_M * seq_prog_id M_blk_idxs = block_start_offset + M_range @@ -508,7 +501,6 @@ def stickbreaking_attn_bwd_one_row_kernel( cm, on_band, ALLOW_TF32, - attend_current=attend_current, backward=True, is_compiling=is_compiling, ) @@ -554,7 +546,6 @@ def stickbreaking_attn_fwd( k: torch.Tensor, v: torch.Tensor, scale: float, - attend_current: bool, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -580,6 +571,12 @@ def stickbreaking_attn_fwd( grid = (1, num_heads, num_seq_blocks) BLOCK_D = triton.next_power_of_2(dim_size) + NO_M_MASK = (token_size % BLOCK_M) == 0 + NO_N_MASK = (token_size % BLOCK_N) == 0 + if cu_seqlens is not None: + NO_M_MASK = False + NO_N_MASK = False + stickbreaking_attn_fwd_kernel[grid]( q, k, @@ -590,15 +587,14 @@ def stickbreaking_attn_fwd( CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, scale=scale, - attend_current=attend_current, batch_size=batch_size, token_size=token_size, head_size=dim_size, num_heads=num_heads, BLOCK_D=BLOCK_D, NO_D_MASK=dim_size == BLOCK_D, - NO_M_MASK=(token_size % BLOCK_M) == 0, - NO_N_MASK=(token_size % BLOCK_N) == 0, + NO_M_MASK=NO_M_MASK, + NO_N_MASK=NO_N_MASK, ALLOW_TF32=ALLOW_TF32, inv_log2=inv_log2, BLOCK_M=BLOCK_M, @@ -619,7 +615,6 @@ def stickbreaking_attn_bwd( v: torch.Tensor, neg_log_acc: torch.Tensor, scale: float, - attend_current: bool, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, token_size, num_heads, dim_size = q.size() @@ -638,6 +633,13 @@ def stickbreaking_attn_bwd( dv = torch.zeros((M_count, batch_size, token_size, num_heads, dim_size), dtype=v.dtype, device=v.device) BLOCK_D = triton.next_power_of_2(dim_size) + + NO_M_MASK = (token_size % BLOCK_M) == 0 + NO_N_MASK = (token_size % BLOCK_N) == 0 + if cu_seqlens is not None: + NO_M_MASK = False + NO_N_MASK = False + 
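+    # The mask-free fast paths assume tile boundaries line up with the end of the sequence;
+    # with packed variable-length inputs the per-sequence boundaries given by cu_seqlens need
+    # not be multiples of BLOCK_M/BLOCK_N, so masking stays enabled whenever cu_seqlens is set.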
stickbreaking_attn_bwd_kernel[grid]( do, dr, @@ -651,7 +653,6 @@ def stickbreaking_attn_bwd( CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, scale=scale, - attend_current=attend_current, batch_size=batch_size, token_size=token_size, head_size=dim_size, @@ -660,8 +661,8 @@ def stickbreaking_attn_bwd( BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D, NO_D_MASK=dim_size == BLOCK_D, - NO_M_MASK=(token_size % BLOCK_M) == 0, - NO_N_MASK=(token_size % BLOCK_N) == 0, + NO_M_MASK=NO_M_MASK, + NO_N_MASK=NO_N_MASK, ALLOW_TF32=ALLOW_TF32, inv_log2=inv_log2, acc_dtype=tl.float32, @@ -683,21 +684,19 @@ def forward( k: torch.Tensor, v: torch.Tensor, scale: float, - attend_current: bool = False, cu_seqlens: torch.LongTensor | None = None, ): - o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, scale, attend_current, cu_seqlens) + o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, scale, cu_seqlens) ctx.save_for_backward(q, k, v, neg_log_acc) ctx.scale = scale - ctx.attend_current = attend_current ctx.cu_seqlens = cu_seqlens return o, rem @staticmethod def backward(ctx, do: torch.Tensor, drem: torch.Tensor): q, k, v, neg_log_acc = ctx.saved_tensors - dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.attend_current, ctx.cu_seqlens) - return dq, dk, dv, None, None, None + dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.cu_seqlens) + return dq, dk, dv, None, None def parallel_stickbreaking_attn( @@ -705,12 +704,11 @@ def parallel_stickbreaking_attn( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, - attend_current: bool = False, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if scale is None: scale = q.shape[-1] ** -0.5 - return StickBreakingAttentionFunction.apply(q, k, v, scale, attend_current, cu_seqlens) + return StickBreakingAttentionFunction.apply(q, k, v, scale, cu_seqlens) __all__ = [ diff --git a/sba_code/.gitignore b/sba_code/.gitignore new file mode 100644 index 000000000..82f927558 --- /dev/null +++ b/sba_code/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/sba_code/LICENSE b/sba_code/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/sba_code/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sba_code/README.md b/sba_code/README.md new file mode 100644 index 000000000..ee44de9d3 --- /dev/null +++ b/sba_code/README.md @@ -0,0 +1,32 @@ +# Stick-breaking Attention Implementation +Triton-based implementation of Stick-breaking Attention on GPUs. +This implementation is for variable length . +You can find the paper [here](https://arxiv.org/abs/2410.17980) + +## Installation +```sh +# Install editable. This will allow you to modify stickbreaking in this directory. +pip install -e . +# Check all is working well. +pytest -x tests +``` +### Usage +#### Variable Length Attention +Each mini-batch consists of concatenated sequences of different lengths. + +`sb_attn_varlen` implements the counterpart to Flash Attention's +[`flash_attn_varlen_func`](https://github.com/Dao-AILab/flash-attention/blob/0dfb28174333d9eefb7c1dd4292690a8458d1e89/flash_attn/flash_attn_interface.py#L1360). 
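+Both `sb_attn` and `sb_attn_varlen` return a pair `(o, rem)`: `o` is the attention output and
+`rem` is the leftover stick-breaking mass per position (one minus the total attention weight
+assigned to the prefix). A minimal sketch of folding that remainder back onto the values, if
+your model uses that convention:
+```python
+# optional: assign the unallocated attention mass to the token's own value
+o = o + rem[..., None] * v
+```
+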
+Assuming we have an input batch that concatenates all documents/sequences into a long array, and the corresponding +sequence lengths in the batch in an array `lengths`. +Then we can compute the cu_seqlens and pass that to `sb_attn_varlen`: +```python +import torch +from stickbreaking_attention.sb_varlen import sb_attn_varlen +# lengths: batch_size, +total_length = torch.sum(lengths) +# q, k, v: num_heads, total_length, head_dima +cu_seqlens = torch.cumsum(lengths) +o, rem = sb_attn_varlen(q, k, v, cu_seqlens, zero_start=False) +``` + +Enjoy! diff --git a/sba_code/benchmarks/attn.py b/sba_code/benchmarks/attn.py new file mode 100644 index 000000000..cc9458416 --- /dev/null +++ b/sba_code/benchmarks/attn.py @@ -0,0 +1,96 @@ +import torch +import pytest +import math +from torch.nn import functional as F +from stickbreaking_attention.sb_attn import sb_attn +import triton +from flash_attn import flash_attn_func +from flash_attn.flash_attn_triton import flash_attn_func as triton_flash_attn_func +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, rotate_half +from transformers import set_seed + + +def tri_fwdbwd(do, q, k, v): + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + o, rem = sb_attn(q, k, v, inv_temp=1 / math.sqrt(q.size(-1))) + o = o.permute(0, 2, 1, 3) + # o = o + rem[..., None] * v + return o + +def flash_fwdbwd(rope, position_ids, do, q, k, v): + cos, sin = rope(v, position_ids) + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + o = flash_attn_func(q, k, v, causal=True) + # o = o.permute(0, 2, 1, 3) + return o + +def triton_flash_fwdbwd(rope, position_ids, do, q, k, v): + cos, sin = rope(v, position_ids) + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + o = triton_flash_attn_func(q, k, v, None, True) + # o = o.permute(0, 2, 1, 3) + return o + + +providers = [ + ("triton", "Stickbreaking", ("blue", "-")), + ("flash", "Flash Attention", ("green", "-")), + # ("triton_flash", "Triton Flash", ("red", "-")), # triton flash not working +] +@triton.testing.perf_report([ + triton.testing.Benchmark( + x_names=["length"], + x_vals=[4096, 2 * 4096, 3 * 4096, 4 * 4096], + line_arg="provider", + line_vals=[x[0] for x in providers], + line_names=[x[1] for x in providers], + styles=[x[2] for x in providers], + ylabel="ms", + plot_name=f"triton v torch", + args={"batch_size": 4, "num_heads": 12, "head_dim": 128, "dtype": torch.bfloat16, "bwd": True} + ) +]) +def benchmark_attn(batch_size, num_heads, head_dim, length, dtype, provider, bwd): + device = torch.device('cuda:0') + set_seed(1337) + warmup = 100 + rep = 1000 + + q = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) + k = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) + v = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + do = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) + position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :] + if provider == "triton": + fun = lambda: tri_fwdbwd(do, q, k, v) + elif provider == "flash": + rope = LlamaRotaryEmbedding(dim=head_dim).to(device) + fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v) + elif provider == 
"triton_flash": + rope = LlamaRotaryEmbedding(dim=head_dim).to(device) + fun = lambda: triton_flash_fwdbwd(rope, position_ids, do, q, k, v) + + if bwd: + def fun_(): + o = fun() + dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) + + return triton.testing.do_bench(fun_, warmup=warmup, rep=rep) + else: + return triton.testing.do_bench(fun, warmup=warmup, rep=rep) + + + +if __name__ == "__main__": + benchmark_attn.run(save_path=None, print_data=True) diff --git a/sba_code/benchmarks/varlen.py b/sba_code/benchmarks/varlen.py new file mode 100644 index 000000000..e04c805a4 --- /dev/null +++ b/sba_code/benchmarks/varlen.py @@ -0,0 +1,129 @@ +import torch +import pytest +import math +from torch.nn import functional as F +from stickbreaking_attention.sb_varlen import sb_attn_varlen +import triton +from flash_attn import flash_attn_varlen_func +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, rotate_half +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import set_seed +from stickbreaking_attention.sb_ref import stickbreaking + + + +def ref_fwd(q, k, v, lengths): + q = q.permute(1, 0, 2) + k = k.permute(1, 0, 2) + v = v.permute(1, 0, 2) + splits = list(lengths.cpu().numpy()) + max_len = max(splits) + cm = torch.ones(max_len, max_len).tril(-1).to(q) + mask = torch.ones(max_len, max_len).triu(0).cuda().bool() + outputs = [] + for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)): + len = q_chunk.size(1) + o, rem = stickbreaking( + q_chunk[None, :], + k_chunk[None, :], + v_chunk[None, :], + mask[:len, :len], cm[:len, :len] + ) + + # o = o + rem[..., None] * v_chunk[None] + outputs.append(o[0]) + return torch.cat(outputs, 1) + +def ref_fwdbwd(do, q, k, v, lengths): + o = ref_fwd(q, k, v, lengths) + return o + + +def tri_fwdbwd(do, q, k, v, lengths): + q = q.permute(1, 0, 2) + k = k.permute(1, 0, 2) + v = v.permute(1, 0, 2) + cu_seqlens = torch.cumsum(lengths, dim=-1) + o, rem = sb_attn_varlen(q, k, v, + cu_seqlens=cu_seqlens, + max_seqlens=max(lengths).item(), + inv_temp=1 / math.sqrt(q.size(-1)), + zero_start=False) + # o = o + rem[..., None] * v + return o + +def flash_fwdbwd(rope, position_ids, do, q, k, v, lengths): + cos, sin = rope(v, position_ids) + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + lengths = lengths.to(torch.int32) + cu_seqlens = torch.cumsum(lengths, dim=-1) + cu_seqlens = F.pad(cu_seqlens, (1, 0)).to(torch.int32) + max_len = torch.max(lengths) + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True + ) + o = o.permute(1, 0, 2) + return o + + +providers = [ + # ("reference", "Stickbreaking (ref.)", ("red", "-")), + ("triton", "Stickbreaking", ("blue", "-")), + ("flash", "Flash Attention", ("green", "-")), +] +@triton.testing.perf_report([ + triton.testing.Benchmark( + x_names=["length"], + x_vals=[4096, 2 * 4096, 3 * 4096, 4 * 4096], + line_arg="provider", + line_vals=[x[0] for x in providers], + line_names=[x[1] for x in providers], + styles=[x[2] for x in providers], + ylabel="ms", + plot_name=f"triton v torch", + args={"batch_size": 4, "num_heads": 12, "head_dim": 128, "dtype": torch.bfloat16, "bwd": True} + ) +]) +def benchmark_varlen(batch_size, num_heads, head_dim, length, dtype, provider, bwd): + device = torch.device('cuda:0') + set_seed(1337) + lengths = torch.randint(length, length + 1, 
(batch_size,)).to(device=device, dtype=torch.int32) + total_length = lengths.sum() + warmup = 100 + rep = 1000 + + q = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) + k = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) + v = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype) + position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :] + + if provider== "reference": + fun = lambda: ref_fwdbwd(do, q, k, v, lengths) + elif provider == "triton": + fun = lambda: tri_fwdbwd(do, q, k, v, lengths) + elif provider == "flash": + config = LlamaConfig(max_position_embeddings=length) + rope = LlamaRotaryEmbedding(config).to(device) + fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths) + if bwd: + def fun_(): + o = fun() + dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) + return triton.testing.do_bench(fun_, warmup=warmup, rep=rep) + else: + return triton.testing.do_bench(fun, warmup=warmup, rep=rep) + + + +if __name__ == "__main__": + benchmark_varlen.run(save_path=None, print_data=True) diff --git a/sba_code/load_model_with_dolomite_demo.py b/sba_code/load_model_with_dolomite_demo.py new file mode 100644 index 000000000..267918dfe --- /dev/null +++ b/sba_code/load_model_with_dolomite_demo.py @@ -0,0 +1,7 @@ +import transformers +from dolomite_engine import hf_models + +if __name__ == "__main__": + hf_models = transformers.AutoModelForCausalLM.from_pretrained( + 'shawntan/stickbreaking-3b', + ) diff --git a/sba_code/setup.py b/sba_code/setup.py new file mode 100644 index 000000000..809fff7c8 --- /dev/null +++ b/sba_code/setup.py @@ -0,0 +1,29 @@ +import os +from setuptools import setup, find_packages + +def read(fname): + return open(os.path.join(os.path.dirname(__file__), fname)).read() + +setup( + name = "stickbreaking_attention", + version = "0.0.0", + author = "Shawn Tan", + author_email = "shawntan@ibm.com", + description = "Triton implementation of Stick-breaking attention", + license = "Apache License", + keywords = "triton pytorch llm stickbreaking attention", + url = "https://github.com/shawntan/scattermoe", + packages=find_packages(), + long_description=read('README.md'), + python_requires='>=3.10.10', + install_requires=[ + 'torch', + 'triton', + ], + tests_require=['pytest', 'numpy'], + classifiers=[ + "Development Status :: 1 - Planning", + "License :: OSI Approved :: Apache Software License", + ], +) + diff --git a/sba_code/stickbreaking_attention/__init__.py b/sba_code/stickbreaking_attention/__init__.py new file mode 100644 index 000000000..e798ad2f8 --- /dev/null +++ b/sba_code/stickbreaking_attention/__init__.py @@ -0,0 +1,2 @@ +from .sb_attn import sb_attn +from .sb_varlen import sb_attn_varlen diff --git a/sba_code/stickbreaking_attention/sb_attn/__init__.py b/sba_code/stickbreaking_attention/sb_attn/__init__.py new file mode 100644 index 000000000..22d55b583 --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_attn/__init__.py @@ -0,0 +1,64 @@ +import math + +import torch +import triton.language as tl + +from .sb_bwd import _bwd +from .sb_fwd import _fwd + + +FWD_BLOCK_M: tl.constexpr = 64 +FWD_BLOCK_N: tl.constexpr = 32 +BWD_BLOCK_M: tl.constexpr = 64 +BWD_BLOCK_N: tl.constexpr = 32 + + +class StickBreakingAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q: 
torch.Tensor, k: torch.Tensor, v: torch.Tensor, inv_temp: float, + attend_current: bool = False): + no_grad = not ctx.needs_input_grad[0] + logit_scale = inv_temp + BLOCK_M = FWD_BLOCK_M + BLOCK_N = FWD_BLOCK_N + o, rem, neg_log_acc = _fwd( + q, k, v, logit_scale=inv_temp, no_grad=no_grad, + return_attention=False, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + attend_current=attend_current + ) + ctx.save_for_backward(q, k, v, neg_log_acc) + ctx.logit_scale = logit_scale + ctx.attend_current = attend_current + return o, rem + + @staticmethod + def backward(ctx, do: torch.Tensor, drem: torch.Tensor): + logit_scale = ctx.logit_scale + attend_current = ctx.attend_current + q, k, v, neg_log_acc = ctx.saved_tensors + BLOCK_M = BWD_BLOCK_M + BLOCK_N = BWD_BLOCK_N + dq, dk, dv = _bwd( + do, + drem, + q, + k, + v, + neg_log_acc, + logit_scale, + attend_current=attend_current, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return dq, dk, dv, None, None + + +def sb_attn(q, k, v, inv_temp=None, zero_start=True, attend_current=False): + if inv_temp is None: + inv_temp = 1 / math.sqrt(q.size(-1)) + return sb_attn_(q, k, v, inv_temp, attend_current=attend_current) + + +def sb_attn_(q, k, v, inv_temp, attend_current): + return StickBreakingAttention.apply(q, k, v, inv_temp, attend_current) diff --git a/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py b/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py new file mode 100644 index 000000000..8b881c5a0 --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py @@ -0,0 +1,297 @@ +import torch +import triton +import triton.language as tl + +from ..utils import ALLOW_TF32, inv_log2, custom_op +from ..sb_varlen.sb_varlen_bwd import _backward_one_row +from ..sb_varlen.sb_varlen_fwd import compute_block, load_kv + + +def get_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) for s in [8] for w in [4]] + + +@triton.autotune( + configs=get_configs(), + key=["token_size", "head_size"], +) +# reset_to_zero=["DK_ptr", "DV_ptr"]) +@triton.jit() +def _backward( + DO_ptr, + stride_dob, + stride_doh, + stride_dom: tl.constexpr, + stride_dod: tl.constexpr, + DR_ptr, + stride_drb, + stride_drh, + stride_drm: tl.constexpr, + A_ptr, + stride_ab, + stride_ah, + stride_am: tl.constexpr, + Q_ptr, + stride_qb, + stride_qh, + stride_qm: tl.constexpr, + stride_qd: tl.constexpr, + K_ptr, + stride_kb, + stride_kh, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + V_ptr, + stride_vb, + stride_vh, + stride_vn: tl.constexpr, + stride_vd: tl.constexpr, + DQ_ptr, + stride_dqb, + stride_dqh, + stride_dqm: tl.constexpr, + stride_dqd: tl.constexpr, + DK_ptr, + stride_dkb, + stride_dkh, + stride_dkn: tl.constexpr, + stride_dkd: tl.constexpr, + DV_ptr, + stride_dvb, + stride_dvh, + stride_dvn: tl.constexpr, + stride_dvd: tl.constexpr, + KV_Lock_ptr, + KV_Count_ptr, + stride_kvb: tl.constexpr, + stride_kvl: tl.constexpr, + logit_scale, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + attend_current: tl.constexpr = False, +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + batch_id = tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + # Universal stuff + qk_scale = inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = 
tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], + 1.0, 0.0).to(Q_ptr.type.element_ty) + + head_id = head_pid + seq_prog_id = prog_id + seq_length = token_size + + DO_head_seq_ptr = DO_ptr + stride_dob * batch_id + stride_doh * head_id + DR_head_seq_ptr = DR_ptr + stride_drb * batch_id + stride_drh * head_id + A_head_seq_ptr = A_ptr + stride_ab * batch_id + stride_ah * head_id + Q_head_seq_ptr = Q_ptr + stride_qb * batch_id + stride_qh * head_id + K_head_seq_ptr = K_ptr + stride_kb * batch_id + stride_kh * head_id + V_head_seq_ptr = V_ptr + stride_vb * batch_id + stride_vh * head_id + DQ_head_seq_ptr = DQ_ptr + stride_dqb * batch_id + stride_dqh * head_id + DK_head_seq_ptr = DK_ptr + stride_dkb * batch_id + stride_dkh * head_id + DV_head_seq_ptr = DV_ptr + stride_dvb * batch_id + stride_dvh * head_id + KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvb * batch_id + stride_kvl * head_id + KV_Count_head_seq_ptr = KV_Count_ptr + \ + stride_kvb * batch_id + stride_kvl * head_id + _backward_one_row( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + stride_dom, + stride_dod, + DR_head_seq_ptr, + stride_drm, + A_head_seq_ptr, + stride_am, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + DQ_head_seq_ptr, + stride_dqm, + stride_dqd, + DK_head_seq_ptr, + stride_dkn, + stride_dkd, + DV_head_seq_ptr, + stride_dvn, + stride_dvd, + KV_Lock_head_seq_ptr, + KV_Count_head_seq_ptr, + logit_scale, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + acc_dtype, + is_compiling=is_compiling, + attend_current=attend_current, + ) + + +def _bwd(do, dr, q, k, v, neg_log_acc, logit_scale, + attend_current=False, BLOCK_M=64, BLOCK_N=32): + batch_size, num_heads, token_size, dim_size = q.size() + M_count = triton.cdiv(token_size, BLOCK_M) + N_count = triton.cdiv(token_size, BLOCK_N) + + # dqdkdv = torch.zeros((batch_size, token_size, num_heads, 3 * dim_size), device=do.device, dtype=do.dtype) + # dqdkdv = dqdkdv.permute(0, 2, 1, 3) + # dq, dk, dv = dqdkdv.chunk(3, dim=-1) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k, dtype=torch.bfloat16) + dv = torch.zeros_like(v, dtype=torch.bfloat16) + + M_count = triton.cdiv(token_size, BLOCK_M) + N_count = M_count * (BLOCK_M // BLOCK_N) + dkdv_lock = torch.zeros((batch_size, num_heads, N_count), + dtype=torch.int32, device=q.device) + dkdv_count = torch.zeros( + (batch_size, num_heads, N_count), dtype=torch.bool, device=q.device) + _compileable_backward( + do, + dr, + q, + k, + v, + neg_log_acc, + logit_scale, + attend_current, + BLOCK_M, + BLOCK_N, + batch_size, + num_heads, + token_size, + dim_size, + M_count, + N_count, + dq, + dk, + dv, + dkdv_lock, + dkdv_count, + ) + return dq, dk, dv + + +@custom_op("attn_bwd", mutates_args={"dq", "dk", "dv", "dkdv_lock", "dkdv_count"}) +def _compileable_backward( + do: torch.Tensor, + dr: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + neg_log_acc: torch.Tensor, + logit_scale: float, + attend_current: bool, + BLOCK_M: int, + BLOCK_N: int, + batch_size: int, + num_heads: int, + token_size: int, + dim_size: int, + M_count: int, + N_count: int, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + dkdv_lock: torch.Tensor, + dkdv_count: torch.Tensor, +) -> None: + BLOCK_D = triton.next_power_of_2(dim_size) + _backward[batch_size, num_heads, M_count]( 
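+        # launch grid: one program instance per (batch, head, BLOCK_M-sized block of query rows)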
+ do, + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + dr, + dr.stride(0), + dr.stride(1), + dr.stride(2), + neg_log_acc, + neg_log_acc.stride(0), + neg_log_acc.stride(1), + neg_log_acc.stride(2), + q, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + dq, + dq.stride(0), + dq.stride(1), + dq.stride(2), + dq.stride(3), + dk, + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv, + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + dkdv_lock, + dkdv_count, + num_heads * N_count, + N_count, + logit_scale=logit_scale, + attend_current=attend_current, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=(token_size % BLOCK_M) == 0, + NO_N_MASK=(token_size % BLOCK_N) == 0, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + acc_dtype=tl.float32, + is_compiling=False, + ) diff --git a/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py b/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py new file mode 100644 index 000000000..b9f04cc2d --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py @@ -0,0 +1,253 @@ +import torch +import triton +import triton.language as tl + +from ..utils import ALLOW_TF32, inv_log2, custom_op +from ..sb_varlen.sb_varlen_fwd import _forward_one_row +from ..sb_varlen.softplus import softplus + + + +def get_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) for s in [4] for w in [4]] + + +@triton.autotune(configs=get_configs(), key=["token_size", "head_size"]) +@triton.jit +def _forward( + Q_ptr, + stride_qb, + stride_qh, + stride_qm: tl.constexpr, + stride_qd: tl.constexpr, + K_ptr, + stride_kb, + stride_kh, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + V_ptr, + stride_vb, + stride_vh, + stride_vn: tl.constexpr, + stride_vd: tl.constexpr, + O_ptr, + stride_ob, + stride_oh, + stride_om: tl.constexpr, + stride_od: tl.constexpr, + R_ptr, + stride_rb, + stride_rh, + stride_rm: tl.constexpr, + A_ptr, + stride_ab, + stride_ah, + stride_am: tl.constexpr, + W_ptr, + stride_wb, + stride_wh, + stride_wm, + stride_wn, + logit_scale: tl.constexpr, + attend_current: tl.constexpr, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + return_attention: tl.constexpr = False, + is_compiling: tl.constexpr = False, +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + batch_id = tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + tl.num_programs(2) + seq_length = token_size + # Universal stuff + qk_scale = inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], + 1.0, 0.0).to(Q_ptr.type.element_ty) + + # First head block + head_id = head_pid + seq_prog_id = prog_id + # tl.store(pid_debug_ptr + head_id * tl.num_programs(1) + prog_id_start_offset + seq_prog_id, pid) + Q_head_seq_ptr = Q_ptr + stride_qb * batch_id + stride_qh * head_id + K_head_seq_ptr = 
K_ptr + stride_kb * batch_id + stride_kh * head_id + V_head_seq_ptr = V_ptr + stride_vb * batch_id + stride_vh * head_id + O_head_seq_ptr = O_ptr + stride_ob * batch_id + stride_oh * head_id + R_head_seq_ptr = R_ptr + stride_rb * batch_id + stride_rh * head_id + A_head_seq_ptr = A_ptr + stride_ab * batch_id + stride_ah * head_id + W_head_seq_ptr = W_ptr + stride_wb * batch_id + stride_wh * head_id + _forward_one_row( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + O_head_seq_ptr, + stride_om, + stride_od, + R_head_seq_ptr, + stride_rm, + A_head_seq_ptr, + stride_am, + W_head_seq_ptr, + stride_wm, + stride_wn, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + no_grad, + acc_dtype, + return_attention, + attend_current=attend_current, + is_compiling=is_compiling, + ) + + +def _fwd(q, k, v, logit_scale, + attend_current=False, + no_grad=False, return_attention=False, + BLOCK_M: int = 64, BLOCK_N: int = 32): + batch_size, num_heads, token_size, dim_size = q.size() + o = torch.empty_like(q) + rem = torch.zeros_like(q[:, :, :, 0], device=q.device) + neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) + if return_attention: + W = torch.full((batch_size, num_heads, token_size, token_size), + 0.0, dtype=torch.float32, device=q.device) + else: + W = torch.empty((1, 1, 1, 1), device=q.device) + _compileable_fwd( + q, + k, + v, + logit_scale, + no_grad, + return_attention, + BLOCK_M, + BLOCK_N, + batch_size, + num_heads, + token_size, + dim_size, + o, + rem, + neg_log_acc, + W, + attend_current=attend_current, + ) + if return_attention: + return o, rem, neg_log_acc, W + else: + return o, rem, neg_log_acc + + +@custom_op("attn_fwd", mutates_args={"o", "rem", "neg_log_acc", "W"}) +def _compileable_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + logit_scale: float, + no_grad: bool, + return_attention: bool, + BLOCK_M: int, + BLOCK_N: int, + batch_size: int, + num_heads: int, + token_size: int, + dim_size: int, + o: torch.Tensor, + rem: torch.Tensor, + neg_log_acc: torch.Tensor, + W: torch.Tensor, + attend_current: bool, +) -> None: + num_folded_heads = num_heads + num_seq_blocks = triton.cdiv(token_size, BLOCK_M) + BLOCK_D = triton.next_power_of_2(dim_size) + grid = (batch_size, num_folded_heads, num_seq_blocks) + _forward[grid]( + q, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + rem, + rem.stride(0), + rem.stride(1), + rem.stride(2), + neg_log_acc, + neg_log_acc.stride(0), + neg_log_acc.stride(1), + neg_log_acc.stride(2), + W, + W.stride(0), + W.stride(1), + W.stride(2), + W.stride(3), + logit_scale=logit_scale, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + no_grad=no_grad, + attend_current=attend_current, + BLOCK_D=BLOCK_D, + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=(token_size % BLOCK_M) == 0, + NO_N_MASK=(token_size % BLOCK_N) == 0, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + return_attention=return_attention, + acc_dtype=tl.float32, + is_compiling=False, + ) diff --git a/sba_code/stickbreaking_attention/sb_ref.py 
b/sba_code/stickbreaking_attention/sb_ref.py new file mode 100644 index 000000000..7e6dc12f2 --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_ref.py @@ -0,0 +1,25 @@ +import math + +import torch +from torch.nn import functional as F + + +# for reference +def stickbreaking(q, k, v, mask, cum_weight): + """ + Stick-breaking attention weights. + """ + logits = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1]) + + original_dtype = logits.dtype + + logits = logits.float() + log_z = F.logsigmoid(logits).masked_fill(mask, -1e5).to(original_dtype) + + log_beta = F.logsigmoid(-logits).masked_fill(mask, 0).to(original_dtype) + + re_cum_log_beta = torch.einsum( + "bhij,jk->bhik", log_beta, cum_weight.to(log_beta)) + log_att = log_z + re_cum_log_beta + att = log_att.exp() + return att @ v, 1 - att.sum(dim=-1) diff --git a/sba_code/stickbreaking_attention/sb_varlen/__init__.py b/sba_code/stickbreaking_attention/sb_varlen/__init__.py new file mode 100644 index 000000000..9a7096504 --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_varlen/__init__.py @@ -0,0 +1,82 @@ +from .sb_varlen_fwd import varlen_fwd +from .sb_varlen_bwd import varlen_bwd +import math + +import torch +import triton.language as tl +from torch.nn import functional as F + + +FWD_BLOCK_M: tl.constexpr = 64 +FWD_BLOCK_N: tl.constexpr = 32 +BWD_BLOCK_M: tl.constexpr = 64 +BWD_BLOCK_N: tl.constexpr = 32 + + +def calculate_programs_needed(cu_seqlens: torch.Tensor, BLOCK_SIZE): + lens = cu_seqlens.clone() + lens[1:] -= cu_seqlens[:-1] + seq_num_programs = ((lens - 1) // BLOCK_SIZE) + 1 + seq_program_offsets = torch.cumsum(seq_num_programs, dim=0) + return seq_program_offsets + + +class StickBreakingAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, cu_seqlens, max_seqlens, inv_temp, attend_current): + no_grad = not ctx.needs_input_grad[0] + logit_scale = inv_temp + o, rem, neg_log_acc = varlen_fwd( + q, + k, + v, + cu_seqlens, + max_seqlens, + logit_scale=inv_temp, + attend_current=attend_current, + no_grad=no_grad, + BLOCK_M=FWD_BLOCK_M, + BLOCK_N=FWD_BLOCK_N, + ) + ctx.save_for_backward(q, k, v, neg_log_acc, cu_seqlens) + ctx.logit_scale = logit_scale + ctx.max_seqlens = max_seqlens + ctx.attend_current = attend_current + return o, rem + + @staticmethod + def backward(ctx, do, drem): + logit_scale = ctx.logit_scale + max_seqlens = ctx.max_seqlens + attend_current = ctx.attend_current + q, k, v, neg_log_acc, cu_seqlens = ctx.saved_tensors + dq, dk, dv = varlen_bwd( + do, + drem, + q, + k, + v, + cu_seqlens, + max_seqlens, + neg_log_acc, + logit_scale, + attend_current=attend_current, + BLOCK_M=BWD_BLOCK_M, + BLOCK_N=BWD_BLOCK_N, + ) + return dq, dk, dv, None, None, None, None + + +def sb_attn_varlen(q, k, v, cu_seqlens, max_seqlens, inv_temp=None, zero_start=True, attend_current=False): + if zero_start: + assert cu_seqlens[0] == 0 + cu_seqlens = cu_seqlens[1:] + if inv_temp is None: + inv_temp = 1 / math.sqrt(q.size(-1)) + + return sb_attn_varlen_(q, k, v, inv_temp, cu_seqlens, max_seqlens, attend_current) + + +def sb_attn_varlen_(q, k, v, inv_temp, cu_seqlens, max_seqlens, attend_current): + return StickBreakingAttention.apply(q, k, v, cu_seqlens, max_seqlens, inv_temp, attend_current) diff --git a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py new file mode 100644 index 000000000..1812b4788 --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py @@ -0,0 +1,641 @@ +import math + +import torch 
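+# Backward kernels for the variable-length (packed-sequence) stick-breaking attention.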
+import triton +import triton.language as tl + +from ..utils import ALLOW_TF32, inv_log2 +from .sb_varlen_fwd import compute_block, load_kv + +from ..utils import custom_op + +@triton.jit +def locked_add(Lock_ptr, Count_ptr, A_ptrs, a, B_ptrs, b, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr, + EVICTION_POLICY: tl.constexpr=tl.constexpr("")): + while tl.atomic_cas(Lock_ptr, 0, 1) == 1: + pass + # tl.device_print("Start locked add.") + count = tl.load(Count_ptr, eviction_policy=EVICTION_POLICY) + if NO_D_MASK: + if NO_N_MASK: + if count == 0: + tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) + else: + a += tl.load(A_ptrs, eviction_policy=EVICTION_POLICY) + b += tl.load(B_ptrs, eviction_policy=EVICTION_POLICY) + tl.store(A_ptrs, a, eviction_policy=EVICTION_POLICY) + tl.store(B_ptrs, b, eviction_policy=EVICTION_POLICY) + + else: + if count == 0: + tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) + else: + a += tl.load(A_ptrs, + mask=N_mask[:, None], eviction_policy=EVICTION_POLICY) + b += tl.load(B_ptrs, + mask=N_mask[:, None], eviction_policy=EVICTION_POLICY) + tl.store(A_ptrs, a, mask=N_mask[:, None], + eviction_policy=EVICTION_POLICY) + tl.store(B_ptrs, b, mask=N_mask[:, None], + eviction_policy=EVICTION_POLICY) + + else: + # if True: # TODO delete + mask = N_mask[:, None] & D_mask[None, :] + if count == 0: + tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) + else: + a += tl.load(A_ptrs, mask=mask, eviction_policy=EVICTION_POLICY) + b += tl.load(B_ptrs, mask=mask, eviction_policy=EVICTION_POLICY) + tl.store(A_ptrs, a, mask=mask, eviction_policy=EVICTION_POLICY) + tl.store(B_ptrs, b, mask=mask, eviction_policy=EVICTION_POLICY) + + # tl.device_print("End locked add.") + tl.atomic_xchg(Lock_ptr, 0) + +@triton.jit +def _locked_add(Lock_ptr, Count_ptr, A_ptrs, a, B_ptrs, b, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr, + EVICTION_POLICY: tl.constexpr=""): + # count = tl.load(Count_ptr, eviction_policy=EVICTION_POLICY) + if NO_D_MASK: + if NO_N_MASK: + tl.atomic_add(A_ptrs, a) + tl.atomic_add(B_ptrs, b) + else: + tl.atomic_add(A_ptrs, a, mask=N_mask[:, None]) + tl.atomic_add(B_ptrs, b, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + tl.atomic_add(A_ptrs, a, mask=mask) + tl.atomic_add(B_ptrs, b, mask=mask) + + +def get_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) + # for mb in [64, 128] + # for nb in [16, 32, 64] + # for s in [8, 7, 6, 5, 4, 3, 2] + # for w in [4 , 2]] + # for mb in [32] + # for nb in [32] + for s in [8] + for w in [4]] + + + +@triton.autotune( + configs=get_configs(), + key=["token_size", "head_size"], + reset_to_zero=["DK_ptr", "DV_ptr", "KV_Lock_ptr", "KV_Count_ptr"] +) +@triton.jit +def _backward( + DO_ptr, + stride_doh: tl.constexpr, + stride_dom, + stride_dod: tl.constexpr, + DR_ptr, + stride_drh, + stride_drm, + A_ptr, + stride_ah, + stride_am, + Q_ptr, + stride_qh: tl.constexpr, + stride_qm, + stride_qd: tl.constexpr, + K_ptr, + stride_kh: tl.constexpr, + stride_kn, + stride_kd: tl.constexpr, + V_ptr, + stride_vh: tl.constexpr, + stride_vn, + stride_vd: tl.constexpr, + DQ_ptr, + stride_dqh: tl.constexpr, + stride_dqm, + stride_dqd: tl.constexpr, + DK_ptr, + stride_dkh: tl.constexpr, + stride_dkn, + stride_dkd: tl.constexpr, + DV_ptr, + stride_dvh: tl.constexpr, + stride_dvn, + stride_dvd: tl.constexpr, + KV_Lock_ptr, + KV_Count_ptr, + stride_kvs, + stride_kvh, + CSL_ptr, + logit_scale, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + 
BLOCK_CSL: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + attend_current: tl.constexpr = False +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + seq_id = tl.program_id(0) + fhead_id = tl.program_id(1) + seq_alloc_prog_id = tl.program_id(2) + num_seq_alloc_progs = tl.num_programs(2) + if seq_id == 0: + seq_start_offset = 0 + else: + seq_start_offset = tl.load(CSL_ptr + seq_id - 1).to(tl.int32) + seq_end_offset = tl.load(CSL_ptr + seq_id).to(tl.int32) + seq_length = seq_end_offset - seq_start_offset + num_seq_blocks = tl.cdiv(seq_length, BLOCK_M) + + seq_a_block_id = num_seq_blocks - seq_alloc_prog_id - 1 + seq_b_block_id = seq_alloc_prog_id - (num_seq_alloc_progs - num_seq_blocks) + + if seq_a_block_id >= 0 or seq_b_block_id >= 0: + # Universal stuff + qk_scale = inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], + 1.0, 0.0).to(Q_ptr.type.element_ty) + + if seq_a_block_id >= 0: + head_id = fhead_id * 2 + DO_head_seq_ptr = DO_ptr + stride_doh * head_id + stride_dom * seq_start_offset + DR_head_seq_ptr = DR_ptr + stride_drh * head_id + stride_drm * seq_start_offset + A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset + Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset + K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset + V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset + DQ_head_seq_ptr = DQ_ptr + stride_dqh * head_id + stride_dqm * seq_start_offset + DK_head_seq_ptr = DK_ptr + stride_dkh * head_id + stride_dkn * seq_start_offset + DV_head_seq_ptr = DV_ptr + stride_dvh * head_id + stride_dvn * seq_start_offset + KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvs * seq_id + stride_kvh * head_id + KV_Count_head_seq_ptr = KV_Count_ptr + \ + stride_kvs * seq_id + stride_kvh * head_id + _backward_one_row( + seq_a_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + stride_dom, + stride_dod, + DR_head_seq_ptr, + stride_drm, + A_head_seq_ptr, + stride_am, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + DQ_head_seq_ptr, + stride_dqm, + stride_dqd, + DK_head_seq_ptr, + stride_dkn, + stride_dkd, + DV_head_seq_ptr, + stride_dvn, + stride_dvd, + KV_Lock_head_seq_ptr, + KV_Count_head_seq_ptr, + logit_scale, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + acc_dtype, + attend_current=attend_current + ) + if seq_b_block_id >= 0 and fhead_id * 2 + 1 < num_heads: + head_id = fhead_id * 2 + 1 + DO_head_seq_ptr = DO_ptr + stride_doh * head_id + stride_dom * seq_start_offset + DR_head_seq_ptr = DR_ptr + stride_drh * head_id + stride_drm * seq_start_offset + A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset + Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset + K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset + V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset + DQ_head_seq_ptr = DQ_ptr + stride_dqh * head_id + stride_dqm * seq_start_offset + DK_head_seq_ptr = DK_ptr + stride_dkh * head_id + stride_dkn * 
seq_start_offset + DV_head_seq_ptr = DV_ptr + stride_dvh * head_id + stride_dvn * seq_start_offset + KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvs * seq_id + stride_kvh * head_id + KV_Count_head_seq_ptr = KV_Count_ptr + \ + stride_kvs * seq_id + stride_kvh * head_id + _backward_one_row( + seq_b_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + stride_dom, + stride_dod, + DR_head_seq_ptr, + stride_drm, + A_head_seq_ptr, + stride_am, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + DQ_head_seq_ptr, + stride_dqm, + stride_dqd, + DK_head_seq_ptr, + stride_dkn, + stride_dkd, + DV_head_seq_ptr, + stride_dvn, + stride_dvd, + KV_Lock_head_seq_ptr, + KV_Count_head_seq_ptr, + logit_scale, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + acc_dtype, + attend_current=attend_current + ) + + +@triton.jit +def _backward_one_row( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + stride_dom, + stride_dod: tl.constexpr, + DR_head_seq_ptr, + stride_drm, + A_head_seq_ptr, + stride_am: tl.constexpr, + Q_head_seq_ptr, + stride_qm, + stride_qd: tl.constexpr, + K_head_seq_ptr, + stride_kn, + stride_kd: tl.constexpr, + V_head_seq_ptr, + stride_vn, + stride_vd: tl.constexpr, + DQ_head_seq_ptr, + stride_dqm, + stride_dqd: tl.constexpr, + DK_head_seq_ptr, + stride_dkn, + stride_dkd: tl.constexpr, + DV_head_seq_ptr, + stride_dvn, + stride_dvd: tl.constexpr, + KV_Lock_ptr, + KV_Count_ptr, + logit_scale, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + attend_current: tl.constexpr = False, +): + # Loading thread information + block_start_offset = BLOCK_M * seq_prog_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + NO_M_MASK = (block_start_offset + BLOCK_M - 1) < seq_length + + N_blk_idxs_start = 0 + N_blk_idxs = N_blk_idxs_start + N_range + + # Init pointers + # Inputs + DO_blk_ptrs = DO_head_seq_ptr + \ + (stride_dom * M_blk_idxs[:, None] + stride_dod * D_range[None, :]) + + K_blk_ptrs = K_head_seq_ptr + \ + (stride_kn * N_blk_idxs[:, None] + stride_kd * D_range[None, :]) + Q_blk_ptrs = Q_head_seq_ptr + \ + (stride_qm * M_blk_idxs[:, None] + stride_qd * D_range[None, :]) + V_blk_ptrs = V_head_seq_ptr + \ + (stride_vn * N_blk_idxs[:, None] + stride_vd * D_range[None, :]) + A_blk_ptrs = A_head_seq_ptr + stride_am * M_blk_idxs + # Outputs + DQ_blk_ptrs = DQ_head_seq_ptr + \ + (stride_dqm * M_blk_idxs[:, None] + stride_dqd * D_range[None, :]) + DK_blk_ptrs = DK_head_seq_ptr + \ + (stride_dkn * N_blk_idxs[:, None] + stride_dkd * D_range[None, :]) + DV_blk_ptrs = DV_head_seq_ptr + \ + (stride_dvn * N_blk_idxs[:, None] + stride_dvd * D_range[None, :]) + DR_blk_ptrs = DR_head_seq_ptr + stride_drm * M_blk_idxs + + # --- Load band vectors --- + if NO_D_MASK: + if NO_M_MASK: + q = tl.load(Q_blk_ptrs) + do = tl.load(DO_blk_ptrs) + dr = tl.load(DR_blk_ptrs) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None]) + do = tl.load(DO_blk_ptrs, mask=M_mask[:, None]) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + else: + MD_mask = M_mask[:, None] & D_mask[None, :] + q = tl.load(Q_blk_ptrs, mask=MD_mask) 
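+        # head_size is not a power of two here, so loads are masked along the D dimension as well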
+ do = tl.load(DO_blk_ptrs, mask=MD_mask) + dr = tl.load(DR_blk_ptrs, mask=M_mask) + neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) + # --- End band vectors --- + + # Init accumulators + neg_log_acc = neg_log_acc.to(dtype=acc_dtype) + grad_prev_acc = tl.zeros((BLOCK_M,), dtype=acc_dtype) + dq = tl.zeros((BLOCK_M, BLOCK_D), dtype=acc_dtype) + + fwd_cm = tl.trans(cm) + # always multiple of number of blocks. + iters = (block_start_offset + BLOCK_M) // BLOCK_N + # if (last_N_blk_idxs_end - sequence_start_offset) % BLOCK_N > 0: + # tl.device_print('remainder') + # Iterate only up to start of sequence + for i in range(iters): + on_band = (iters - i - 1) < BLOCK_M // BLOCK_N + N_mask = N_blk_idxs < seq_length + NO_N_MASK = (N_blk_idxs_start + BLOCK_N - 1) < seq_length + # --- Recompute block --- + k, v = load_kv( + K_blk_ptrs, + V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=(N_blk_idxs_start + BLOCK_N - 1) < seq_length, + # N_mask=N_mask, NO_N_MASK=False, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + p, log_om_beta, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + attend_current=attend_current, + backward=True, + is_compiling=is_compiling, + ) + + if not NO_M_MASK: + neg_log_acc = tl.where(M_mask, neg_log_acc, 0.0) + + # --- Do gradient stuff --- + att_dA = p * \ + (tl.dot(do, tl.trans(v), allow_tf32=ALLOW_TF32) - dr[:, None]) + cumul_att_dA = ( + tl.dot(att_dA.to(cm.dtype), fwd_cm, + allow_tf32=ALLOW_TF32) + grad_prev_acc[:, None] + ) # 180 -> 174 + # cumul_att_dA = tl.cumsum(att_dA, axis=1) + grad_prev_acc[:, None] # 180 -> 174 + grad_prev_acc += tl.sum(att_dA, axis=1) + beta = 1 - tl.exp2(log_om_beta) # 180 -> 175 + dqk = att_dA - beta * cumul_att_dA + + dq = tl.dot(dqk.to(k.dtype), k, acc=dq, allow_tf32=ALLOW_TF32) + block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, + allow_tf32=ALLOW_TF32) * logit_scale + block_dv = tl.dot(tl.trans(p), do.to(p.dtype), allow_tf32=ALLOW_TF32) + + locked_add( + KV_Lock_ptr + i, + KV_Count_ptr + i, + DK_blk_ptrs, + block_dk, + DV_blk_ptrs, + block_dv, + N_mask, + NO_N_MASK, + D_mask, + NO_D_MASK, + ) + + # --- End gradient stuff --- + N_blk_idxs += BLOCK_N + N_blk_idxs_start += BLOCK_N + K_blk_ptrs += BLOCK_N * stride_kn + V_blk_ptrs += BLOCK_N * stride_vn + DK_blk_ptrs += BLOCK_N * stride_dkn + DV_blk_ptrs += BLOCK_N * stride_dvn + + dq = (logit_scale * dq).to(DQ_head_seq_ptr.type.element_ty) + + if NO_D_MASK: + tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None]) + else: + tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None] & D_mask[None, :]) + + +def varlen_bwd( + do: torch.Tensor, + dr: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlens: int, + neg_log_acc: torch.Tensor, + logit_scale, + attend_current=False, + BLOCK_M=64, + BLOCK_N=32, +): + batch_size = cu_seqlens.size(0) + num_heads, token_size, dim_size = q.size() + if logit_scale is None: + logit_scale = 1 / math.sqrt(dim_size) + N_count = triton.cdiv(token_size, BLOCK_N) + + # dqdkdv = torch.zeros((token_size, num_heads, 3 * dim_size), device=do.device, dtype=do.dtype) + # dqdkdv = dqdkdv.permute(1, 0, 2) + # dq, dk, dv = dqdkdv.chunk(3, dim=-1) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + num_sequences = batch_size + num_folded_heads = triton.cdiv(num_heads, 2) + num_seq_blocks = triton.cdiv(max_seqlens, BLOCK_M) + 1 + _compileable_backward( + do, + dr, + q, + k, + v, + cu_seqlens, + neg_log_acc, + logit_scale, + BLOCK_M, + BLOCK_N, + 
batch_size, + num_heads, + token_size, + dim_size, + dq, + dk, + dv, + # dkdv_lock, + # dkdv_count, + num_sequences, + num_folded_heads, + num_seq_blocks, + attend_current=attend_current + ) + return dq, dk, dv + + +@custom_op("varlen_bwd", mutates_args={"dq", "dk", "dv"}) +def _compileable_backward( + do: torch.Tensor, + dr: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + neg_log_acc: torch.Tensor, + logit_scale: float, + BLOCK_M: int, + BLOCK_N: int, + batch_size: int, + num_heads: int, + token_size: int, + dim_size: int, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + # dkdv_lock: torch.Tensor, + # dkdv_count: torch.Tensor, + num_sequences: int, + num_folded_heads: int, + num_seq_blocks: int, + attend_current: bool = False, +) -> None: + BLOCK_D = triton.next_power_of_2(dim_size) + N_count = num_seq_blocks * (BLOCK_M // BLOCK_N) + dkdv_lock = torch.zeros( + (num_sequences, num_heads, N_count), dtype=torch.int32, device=q.device) + dkdv_count = torch.zeros( + (num_sequences, num_heads, N_count), dtype=torch.int32, device=q.device) + + _backward[num_sequences, num_folded_heads, num_seq_blocks]( + # DO_ptr, stride_doh, stride_dom, stride_dod, + do, + do.stride(0), + do.stride(1), + do.stride(2), + # DR_ptr, stride_drh, stride_drm, + dr, + dr.stride(0), + dr.stride(1), + # A_ptr, stride_ah, stride_am, + neg_log_acc, + neg_log_acc.stride(0), + neg_log_acc.stride(1), + # Q_ptr, stride_qh, stride_qm, stride_qd, + q, + q.stride(0), + q.stride(1), + q.stride(2), + # K_ptr, stride_kh, stride_kn, stride_kd, + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V_ptr, stride_vh, stride_vn, stride_vd, + v, + v.stride(0), + v.stride(1), + v.stride(2), + # DQ_ptr, stride_dqh, stride_dqm, stride_dqd, + dq, + dq.stride(0), + dq.stride(1), + dq.stride(2), + # DK_ptr, stride_dkh, stride_dkn, stride_dkd, + dk, + dk.stride(0), + dk.stride(1), + dk.stride(2), + # DV_ptr, stride_dvh, stride_dvn, stride_dvd, + dv, + dv.stride(0), + dv.stride(1), + dv.stride(2), + # KV_Lock_ptr, KV_Count_ptr, stride_kvl, + dkdv_lock, + dkdv_count, + dkdv_lock.stride(0), + dkdv_lock.stride(1), + cu_seqlens, + logit_scale=logit_scale, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + BLOCK_CSL=triton.next_power_of_2(batch_size), + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=False, + NO_N_MASK=False, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + attend_current=attend_current, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) diff --git a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py new file mode 100644 index 000000000..eceb7b78b --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py @@ -0,0 +1,522 @@ +import torch +import triton +import triton.language as tl + +from ..utils import ALLOW_TF32, inv_log2 +from .softplus import softplus +from ..utils import custom_op + + +@triton.jit +def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): + if NO_D_MASK: + if NO_N_MASK: + k = tl.load(K_blk_ptrs) + v = tl.load(V_blk_ptrs) + else: + k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) + v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + k = tl.load(K_blk_ptrs, mask=mask) + v = tl.load(V_blk_ptrs, mask=mask) + return k, v + + +@triton.jit +def compute_block( + q, + k, + qk_scale, + neg_log_acc, + 
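+    # neg_log_acc: running sum of log2(1 - beta) over the keys processed so far, i.e. log2 of the stick mass still unallocated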
M_blk_idxs, + N_blk_idxs, + cm, + on_band: tl.constexpr, + ALLOW_TF32: tl.constexpr, + backward: tl.constexpr, + attend_current: tl.constexpr = False, + use_cumsum: tl.constexpr = False, + is_compiling: tl.constexpr = False, +): + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale + + # log_om_beta (one minus beta) : log(1 - \beta) + log_om_beta = -softplus(qk, is_compiling=is_compiling) + + if on_band: # diagonal + if attend_current: + block_mask = M_blk_idxs[:, None] >= N_blk_idxs[None, :] + else: + block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] + log_om_beta = tl.where(block_mask, log_om_beta, 0.0) + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, + acc=log_p, allow_tf32=ALLOW_TF32) + + p = tl.math.exp2(log_p) + p = tl.where(block_mask, p, 0.0) + else: + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, + acc=log_p, allow_tf32=ALLOW_TF32) + + p = tl.math.exp2(log_p) + if not backward: + neg_log_acc += tl.sum(log_om_beta, axis=1) + return p, log_om_beta, neg_log_acc + + +@triton.jit +def _forward_one_row( + seq_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + stride_qm, + stride_qd: tl.constexpr, + K_head_seq_ptr, + stride_kn, + stride_kd: tl.constexpr, + V_head_seq_ptr, + stride_vn, + stride_vd: tl.constexpr, + O_head_seq_ptr, + stride_om, + stride_od: tl.constexpr, + R_head_seq_ptr, + stride_rm, + A_head_seq_ptr, + stride_am, + W_head_seq_ptr, + stride_wm, + stride_wn, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + return_attention: tl.constexpr = False, + is_compiling: tl.constexpr = False, + use_cumsum: tl.constexpr = False, + attend_current: tl.constexpr = False, +): + # Loading thread information + block_start_offset = BLOCK_M * seq_block_id + M_blk_idxs = block_start_offset + M_range + M_mask = M_blk_idxs < seq_length + NO_M_MASK = (block_start_offset + BLOCK_M - 1) < seq_length + + # BLOCK_M must be a multiple of BLOCK_N + N_blk_idxs_start = block_start_offset + BLOCK_M + N_blk_idxs = N_blk_idxs_start + N_range + + # Init pointers + Q_blk_ptrs = Q_head_seq_ptr + \ + (stride_qm * M_blk_idxs[:, None] + stride_qd * D_range[None, :]) + K_blk_ptrs = K_head_seq_ptr + \ + (stride_kn * N_blk_idxs[:, None] + stride_kd * D_range[None, :]) + V_blk_ptrs = V_head_seq_ptr + \ + (stride_vn * N_blk_idxs[:, None] + stride_vd * D_range[None, :]) + O_blk_ptrs = O_head_seq_ptr + \ + (stride_om * M_blk_idxs[:, None] + stride_od * D_range[None, :]) + R_blk_ptrs = R_head_seq_ptr + stride_rm * M_blk_idxs + A_blk_ptrs = A_head_seq_ptr + stride_am * M_blk_idxs + + # --- Load band vectors --- + if NO_D_MASK: + if NO_M_MASK: + q = tl.load(Q_blk_ptrs) + else: + q = tl.load(Q_blk_ptrs, mask=M_mask[:, None], other=0.0) + else: + q = tl.load( + Q_blk_ptrs, mask=M_mask[:, None] & D_mask[None, :], other=0.0) + + iters = N_blk_idxs_start // BLOCK_N + neg_log_acc = tl.zeros([BLOCK_M], dtype=acc_dtype) + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=acc_dtype) + # --- End band 
vectors --- + + # Iterate only up to start of sequence + for i in range(iters): + N_blk_idxs -= BLOCK_N + N_blk_idxs_start -= BLOCK_N + K_blk_ptrs -= BLOCK_N * stride_kn + V_blk_ptrs -= BLOCK_N * stride_vn + + N_mask = N_blk_idxs < seq_length + k, v = load_kv( + K_blk_ptrs, + V_blk_ptrs, + N_mask=N_mask, + NO_N_MASK=N_blk_idxs_start + BLOCK_N - 1 < seq_length, + D_mask=D_mask, + NO_D_MASK=NO_D_MASK, + ) + on_band = i < BLOCK_M // BLOCK_N + p, _, neg_log_acc = compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band, + ALLOW_TF32, + attend_current=attend_current, + backward=False, + is_compiling=is_compiling, + use_cumsum=use_cumsum, + ) + # Store intermediate values + acc = tl.dot(p.to(v.dtype), v, acc, allow_tf32=ALLOW_TF32) + if return_attention: # TODO write returns_attention_weight + tl.store( + W_head_seq_ptr + stride_wm * + M_blk_idxs[:, None] + stride_wn * N_blk_idxs[None, :], + p, + mask=(M_blk_idxs < seq_length)[:, None] & ( + N_blk_idxs < seq_length)[None, :], + ) + if NO_M_MASK: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc)) + tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty)) + else: + tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc), mask=M_mask) + tl.store(A_blk_ptrs, neg_log_acc.to( + A_head_seq_ptr.type.element_ty), mask=M_mask) + if NO_D_MASK: + tl.store(O_blk_ptrs, acc.to( + O_head_seq_ptr.type.element_ty), mask=M_mask[:, None]) + else: + tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), + mask=M_mask[:, None] & D_mask[None, :]) + + +def get_configs(): + return [triton.Config({}, num_stages=s, num_warps=w) + # for mb in [64, 128] + # for nb in [16, 32, 64] + for s in [4] # , 2, 3, 5, 6, 7, 8] + for w in [4]] # , 2]] + # for mb in [64] + # for nb in [32] + # for s in [4] + # for w in [4]] + + + +@triton.autotune(configs=get_configs(), key=["head_size"]) +@triton.jit +def _forward( + Q_ptr, + stride_qh: tl.constexpr, + stride_qm, + stride_qd: tl.constexpr, + K_ptr, + stride_kh: tl.constexpr, + stride_kn, + stride_kd: tl.constexpr, + V_ptr, + stride_vh: tl.constexpr, + stride_vn, + stride_vd: tl.constexpr, + O_ptr, + stride_oh: tl.constexpr, + stride_om, + stride_od: tl.constexpr, + R_ptr, + stride_rh, + stride_rm: tl.constexpr, + A_ptr, + stride_ah, + stride_am: tl.constexpr, + W_ptr, + stride_wh, + stride_wm, + stride_wn, + CSL_ptr, + logit_scale: tl.constexpr, + batch_size, + token_size, + head_size: tl.constexpr, + num_heads: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + inv_log2: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + return_attention: tl.constexpr = False, + use_cumsum: tl.constexpr = False, + attend_current: tl.constexpr = False +): + tl.static_assert(BLOCK_M % BLOCK_N == 0) + seq_id = tl.program_id(0) + fhead_id = tl.program_id(1) + seq_alloc_prog_id = tl.program_id(2) + num_seq_alloc_progs = tl.num_programs(2) + if seq_id == 0: + seq_start_offset = 0 + else: + seq_start_offset = tl.load(CSL_ptr + seq_id - 1).to(tl.int32) + seq_end_offset = tl.load(CSL_ptr + seq_id).to(tl.int32) + seq_length = seq_end_offset - seq_start_offset + num_seq_blocks = tl.cdiv(seq_length, BLOCK_M) + + seq_a_block_id = num_seq_blocks - seq_alloc_prog_id - 1 + seq_b_block_id = seq_alloc_prog_id - (num_seq_alloc_progs - num_seq_blocks) + + if seq_a_block_id >= 0 or seq_b_block_id >= 0: + # Universal stuff + qk_scale = 
inv_log2 * logit_scale + M_range = tl.arange(0, BLOCK_M) + N_range = tl.arange(0, BLOCK_N) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + if not use_cumsum: + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to( + Q_ptr.type.element_ty) + else: + cm = None + + if seq_a_block_id >= 0: + # First head block + head_id = fhead_id * 2 + Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset + K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset + V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset + O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset + R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset + A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset + W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset + _forward_one_row( + seq_a_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + O_head_seq_ptr, + stride_om, + stride_od, + R_head_seq_ptr, + stride_rm, + A_head_seq_ptr, + stride_am, + W_head_seq_ptr, + stride_wm, + stride_wn, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + no_grad, + acc_dtype, + return_attention, + use_cumsum=use_cumsum, + attend_current=attend_current + ) + if seq_b_block_id >= 0 and fhead_id * 2 + 1 < num_heads: + # Reverse head block + head_id = fhead_id * 2 + 1 + Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset + K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset + V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset + O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset + R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset + A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset + W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset + _forward_one_row( + seq_b_block_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + stride_qm, + stride_qd, + K_head_seq_ptr, + stride_kn, + stride_kd, + V_head_seq_ptr, + stride_vn, + stride_vd, + O_head_seq_ptr, + stride_om, + stride_od, + R_head_seq_ptr, + stride_rm, + A_head_seq_ptr, + stride_am, + W_head_seq_ptr, + stride_wm, + stride_wn, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BLOCK_M, + BLOCK_N, + no_grad, + acc_dtype, + return_attention, + use_cumsum=use_cumsum, + attend_current=attend_current + ) + + +def varlen_fwd( + q, k, v, cu_seqlens, max_seqlens, logit_scale, attend_current=False, no_grad=False, return_attention=False, BLOCK_M=64, BLOCK_N=32 +): + batch_size = cu_seqlens.size(0) + num_heads, token_size, dim_size = q.size() + o = torch.empty_like(q) + rem = torch.zeros_like(q[:, :, 0], device=q.device) + neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) + if return_attention: + W = torch.full((num_heads, token_size, token_size), 0.0, + dtype=torch.float32, device=q.device) + else: + W = torch.empty((1, 1, 1), device=q.device) + + _compileable_forward( + q, + k, + v, + cu_seqlens, + max_seqlens, + logit_scale, + no_grad, + return_attention, + BLOCK_M, + BLOCK_N, + num_heads, + batch_size, + token_size, + dim_size, + o, + rem, + neg_log_acc, + W, + 
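+        # W is only materialized when return_attention=True; otherwise it is a (1, 1, 1) placeholder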
attend_current=attend_current + ) + if return_attention: + return o, rem, neg_log_acc, W + else: + return o, rem, neg_log_acc + + +@custom_op("varlen_fwd", mutates_args={"o", "rem", "neg_log_acc", "W"}) +def _compileable_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlens: int, + logit_scale: float, + no_grad: bool, + return_attention: bool, + BLOCK_M: int, + BLOCK_N: int, + num_heads: int, + batch_size: int, + token_size: int, + dim_size: int, + o: torch.Tensor, + rem: torch.Tensor, + neg_log_acc: torch.Tensor, + W: torch.Tensor, + attend_current: bool, +) -> None: + num_sequences = batch_size + num_folded_heads = triton.cdiv(num_heads, 2) + num_seq_blocks = triton.cdiv(max_seqlens, BLOCK_M) + 1 + BLOCK_D = triton.next_power_of_2(dim_size) + grid = (num_sequences, num_folded_heads, num_seq_blocks) + q_stride = q.stride() + k_stride = k.stride() + v_stride = v.stride() + o_stride = o.stride() + + _forward[grid]( + q, q_stride[0], q_stride[1], q_stride[2], + k, k_stride[0], k_stride[1], k_stride[2], + v, v_stride[0], v_stride[1], v_stride[2], + o, o_stride[0], o_stride[1], o_stride[2], + rem, + rem.stride(0), + rem.stride(1), + neg_log_acc, + neg_log_acc.stride(0), + neg_log_acc.stride(1), + W, + W.stride(0), + W.stride(1), + W.stride(2), + cu_seqlens, + # pid_debug, + logit_scale=logit_scale, + batch_size=batch_size, + token_size=token_size, + head_size=dim_size, + num_heads=num_heads, + no_grad=no_grad, + BLOCK_D=BLOCK_D, + NO_D_MASK=BLOCK_D == dim_size, + NO_M_MASK=False, + NO_N_MASK=False, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + ALLOW_TF32=ALLOW_TF32, + inv_log2=inv_log2, + return_attention=return_attention, + acc_dtype=tl.float32, + use_cumsum=False, + attend_current=attend_current, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) diff --git a/sba_code/stickbreaking_attention/sb_varlen/softplus.py b/sba_code/stickbreaking_attention/sb_varlen/softplus.py new file mode 100644 index 000000000..50e60a11a --- /dev/null +++ b/sba_code/stickbreaking_attention/sb_varlen/softplus.py @@ -0,0 +1,52 @@ +import torch +import triton +from triton import language as tl + + +def _generate_asm(num_pack): + template = """ + .reg .pred p; + setp.gt.f32 p, ${in_reg}, 15.; + @p mov.f32 ${out_reg}, ${in_reg}; + @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg}; + @!p add.f32 ${out_reg}, ${out_reg}, 1.0; + @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; + """ + out_str = "" + + for i in range(num_pack): + inner_str = template.format(out_reg=i, in_reg=i + num_pack) + out_str += "{" + inner_str + "}\n" + # flatten out because torch.compile doesn't like newlines + out_str = " ".join(out_str.split("\n")) + return out_str + + +def _generate_constraints(num_pack): + return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack)) + + +_NUM_REG = 1 +asm_str: tl.constexpr = tl.constexpr(_generate_asm(_NUM_REG)) +constraints_str: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG)) +NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG) + + +@triton.jit +def softplus(x, is_compiling: tl.constexpr = False): + if is_compiling: + tl.static_print("Using triton softplus.") + out = tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) + return out + else: + out = tl.inline_asm_elementwise( + asm=asm_str, + constraints=constraints_str, + pack=NUM_REG, + args=[ + x, + ], + dtype=tl.float32, + is_pure=True, + ) + return out diff --git a/sba_code/stickbreaking_attention/utils.py b/sba_code/stickbreaking_attention/utils.py new file mode 100644 index 
000000000..4810b2542 --- /dev/null +++ b/sba_code/stickbreaking_attention/utils.py @@ -0,0 +1,39 @@ +import torch +from typing import Callable, Iterable, Sequence +import math + +PACKAGE_NAME = "stickbreaking_attention" +log2 = math.log(2) +inv_log2 = 1 / log2 +ALLOW_TF32 = True + + + +def _dispatch(func: Callable, compileable_fn: Callable, *args, **kwargs): + if torch.compiler.is_compiling(): + output = compileable_fn(*args, **kwargs) + else: + output = func(*args, **kwargs) + return output + + +def custom_op( + name: str = None, + mutates_args: str | Iterable[str] = None, + device_types: str | Sequence[str] | None = None, + schema: str | None = None, +) -> Callable: + compileable_name = f"{PACKAGE_NAME}::{name}" + + def _inner(func: Callable): + compileable_func = torch.library.custom_op( + compileable_name, func, mutates_args=mutates_args, device_types=device_types, schema=schema + ) + + def _run(*args, **kwargs): + return _dispatch(func, compileable_func, *args, **kwargs) + # _run.__signature__ = inspect.signature(func) + # _run.__name__ = func.__name__ + return _run + + return _inner diff --git a/sba_code/tests/__init__.py b/sba_code/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sba_code/tests/test_attn.py b/sba_code/tests/test_attn.py new file mode 100644 index 000000000..41b448c8a --- /dev/null +++ b/sba_code/tests/test_attn.py @@ -0,0 +1,74 @@ +import torch +import pytest +import math +from stickbreaking_attention.sb_attn import sb_attn +from transformers import set_seed +from stickbreaking_attention.sb_ref import stickbreaking +from .test_varlen import assert_close + + +def ref_fwd(q, k, v, length, attend_current=False): + cm = torch.ones(length, length).tril(-1).to(q) + if attend_current: + mask = torch.ones(length, length).triu(1).cuda().bool() + else: + mask = torch.ones(length, length).triu(0).cuda().bool() + o, rem = stickbreaking(q, k, v, mask, cm) + o = o + rem[..., None] * v + return o + +def ref_fwdbwd(do, q, k, v, length, attend_current=False): + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + output = ref_fwd(q, k, v, length, attend_current) + output.backward(do) + dq = q.grad + dk = k.grad + dv = v.grad + q.grad = None + k.grad = None + v.grad = None + return output, dq, dk, dv + + +class TestClass: + + @pytest.mark.parametrize('batch_size', [4, 2, 1]) + @pytest.mark.parametrize('num_heads', [24, 8, 4, 2, 1, 7]) + @pytest.mark.parametrize('head_dim', [64, 32, 16, 50]) + @pytest.mark.parametrize('length', [4096, 2048, 1024, 512, 256, 500]) + @pytest.mark.parametrize('dtype', [torch.bfloat16]) + @pytest.mark.parametrize('forward_only', [False]) + @pytest.mark.parametrize('attend_current', [False, True]) + def test_varlen(self, batch_size, num_heads, head_dim, attend_current, length, dtype, forward_only): + set_seed(1337) + torch.set_printoptions(linewidth=110, edgeitems=30) + device = torch.device('cuda:0') + input_dims = (batch_size, num_heads, length, head_dim) + v = 0.25 * torch.randn(input_dims, device=device, dtype=torch.float32) + q = 0.25 * (torch.randn(input_dims, device=device, dtype=torch.float32) + 1) + k = 0.25 * (torch.randn(input_dims, device=device, dtype=torch.float32) - 1) + print(q.max(), k.max(), v.max()) + q = q.to(dtype).requires_grad_() + k = k.to(dtype).requires_grad_() + v = v.to(dtype).requires_grad_() + do = torch.randn(input_dims, device=device, dtype=dtype) + + with torch.cuda.device(device): + o, rem= sb_attn( + q, k, v, + inv_temp=1 / math.sqrt(q.size(-1)), + 
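+                # explicit inv_temp matches sb_attn's default 1/sqrt(head_dim) scaling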
attend_current=attend_current + ) + o = o + rem[..., None] * v + ref_out, ref_dq, ref_dk, ref_dv = ref_fwdbwd(do, q, k, v, length, + attend_current=attend_current) + eps = 0.05 + torch.cuda.synchronize() + assert_close("o", ref_out, o, eps) + if not forward_only: + dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) + assert_close("dq", ref_dq, dq, eps) + assert_close("dk", ref_dk, dk, eps) + assert_close("dv", ref_dv, dv, eps) diff --git a/sba_code/tests/test_varlen.py b/sba_code/tests/test_varlen.py new file mode 100644 index 000000000..12b5073e6 --- /dev/null +++ b/sba_code/tests/test_varlen.py @@ -0,0 +1,110 @@ +import torch +import pytest +import math +from torch.nn import functional as F +from stickbreaking_attention.sb_varlen import sb_attn_varlen +from transformers import set_seed +from stickbreaking_attention.sb_ref import stickbreaking + + +def ref_fwd(q, k, v, lengths, attend_current=False): + splits = list(lengths.cpu().numpy()) + max_len = max(splits) + cm = torch.ones(max_len, max_len).tril(-1).to(q) + mask = torch.ones(max_len, max_len).triu(0 if not attend_current else 1).cuda().bool() + outputs = [] + for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)): + len = q_chunk.size(1) + o, rem = stickbreaking( + q_chunk[None, :], + k_chunk[None, :], + v_chunk[None, :], + mask[:len, :len], cm[:len, :len] + ) + + o = o + rem[..., None] * v_chunk[None] + outputs.append(o[0]) + return torch.cat(outputs, 1) + +def ref_bwd(do, q, k, v, lengths, attend_current=False): + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + output = ref_fwd(q, k, v, lengths, attend_current=attend_current) + output.backward(do) + dq = q.grad + dk = k.grad + dv = v.grad + q.grad = None + k.grad = None + v.grad = None + return output, dq, dk, dv + +def assert_close(varname, a, b, eps): + if torch.isnan(a).any(): + print("Reference is nan") + return + assert not torch.isnan(b).any() + diff = (a - b).abs() + + max_diff= diff.max() + if max_diff < eps: + print(varname, max_diff.item()) + else: + print(varname, max_diff.item(), diff.median().item()) + print((diff.sum(0).median(dim=0)[0] > eps).int()) + err_locs = (diff.sum(0).median(dim=1)[0] > eps).int() + print(err_locs, err_locs.sum()) + assert max_diff < eps, max_diff + + + +class TestClass: + + # @pytest.mark.parametrize('batch_size', [4, 2, 1]) + # @pytest.mark.parametrize('num_heads', [24, 8, 4, 2, 1, 7]) + # @pytest.mark.parametrize('head_dim', [64, 32, 16, 50]) + # @pytest.mark.parametrize('length', [4096, 2048, 1024, 512, 256, 500]) + @pytest.mark.parametrize('batch_size', [1]) + @pytest.mark.parametrize('num_heads', [12, 3]) + @pytest.mark.parametrize('head_dim', [128]) + @pytest.mark.parametrize('length', [4096, 8192, 8192 * 2]) + @pytest.mark.parametrize('dtype', [torch.bfloat16]) + @pytest.mark.parametrize('forward_only', [False]) + @pytest.mark.parametrize('attend_current', [False, True]) + def test_varlen(self, batch_size, num_heads, head_dim, length, attend_current, dtype, forward_only): + set_seed(1337) + torch.set_printoptions(linewidth=110, edgeitems=30) + device = torch.device('cuda:0') + lengths = torch.randint(length, length + 1, (batch_size,)).to(device=device, dtype=torch.int32) + print(lengths) + total_length = lengths.sum() + cu_seqlens = torch.cumsum(lengths, dim=-1) + v = 0.25 * torch.randn((num_heads, total_length, head_dim), device=device, dtype=torch.float32) + q = 0.25 * (torch.randn((num_heads, total_length, head_dim), device=device, 
dtype=torch.float32) + 1) + k = 0.25 * (torch.randn((num_heads, total_length, head_dim), device=device, dtype=torch.float32) - 1) + print(q.max(), k.max(), v.max()) + q = q.to(dtype) + k = k.to(dtype) + v = v.to(dtype) + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype) + with torch.cuda.device(device): + o, rem = sb_attn_varlen(q, k, v, + cu_seqlens=cu_seqlens, + max_seqlens=torch.max(lengths).item(), + inv_temp=1 / math.sqrt(q.size(-1)), + zero_start=False, + attend_current=attend_current) + o = o + rem[..., None] * v + ref_out, ref_dq, ref_dk, ref_dv = ref_bwd(do, q, k, v, lengths, attend_current=attend_current) + eps = 0.05 + torch.cuda.synchronize() + assert_close("o", ref_out, o, eps) + if not forward_only: + dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) + assert_close("dq", ref_dq, dq, eps) + assert_close("dk", ref_dk, dk, eps) + assert_close("dv", ref_dv, dv, eps) diff --git a/tests/ops/test_stickbreaking_attn.py b/tests/ops/test_stickbreaking_attn.py index f3baf876a..056b375cc 100644 --- a/tests/ops/test_stickbreaking_attn.py +++ b/tests/ops/test_stickbreaking_attn.py @@ -12,8 +12,10 @@ [ pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) for test in [ - (2, 64, 2, 64, torch.float32), - (1, 128, 4, 64, torch.bfloat16), + (2, 128, 2, 64, torch.float16), + (1, 256, 4, 64, torch.float16), + (2, 512, 4, 64, torch.float16), + (4, 1024, 4, 128, torch.float16), ] ], ) @@ -40,7 +42,7 @@ def test_stickbreaking_attn( scale = 1.0 / math.sqrt(D) # Reference (naive) - ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, scale, attend_current=False) + ref_o, ref_rem = naive_stickbreaking_attn(q, k, v, scale) (ref_o * do).sum().backward(retain_graph=True) (ref_rem * dr).sum().backward() ref_dq, q.grad = q.grad.clone(), None @@ -48,7 +50,7 @@ def test_stickbreaking_attn( ref_dv, v.grad = v.grad.clone(), None # Triton fused - tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale, attend_current=False) + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale) (tri_o * do).sum().backward(retain_graph=True) (tri_rem * dr).sum().backward() tri_dq, q.grad = q.grad.clone(), None @@ -61,3 +63,76 @@ def test_stickbreaking_attn( assert_close("dq", ref_dq, tri_dq, 0.02) assert_close("dk", ref_dk, tri_dk, 0.02) assert_close("dv", ref_dv, tri_dv, 0.02) + + +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 64, [0, 63], torch.float16), + (4, 64, [0, 256, 500, 1000], torch.float16), + (4, 128, [0, 15, 100, 300, 1200, 2000], torch.float16), + (2, 128, [0, 100, 123, 300, 500, 800, 1000, 1500, 2048], torch.float16), + ] + ], +) +@pytest.mark.skipif( + is_intel_alchemist, + reason="Skipping test on Intel Alchemist due to known issues with SRAM.", +) +def test_stickbreaking_attn_varlen( + H: int, + D: int, + cu_seqlens: list[int], + dtype: torch.dtype, +): + torch.manual_seed(42) + + T = cu_seqlens[-1] + num_chunks = len(cu_seqlens) - 1 + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + + q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_(True) + + do = torch.randn_like(q) + dr = torch.randn((1, T, H), dtype=dtype, device=device) + + scale = 1.0 / 
math.sqrt(D) + + ref_os = [] + ref_rems = [] + for idx in range(num_chunks): + start, end = cu_seqlens[idx], cu_seqlens[idx + 1] + ref_o_chunk, ref_rem_chunk = naive_stickbreaking_attn( + q[:, start:end], + k[:, start:end], + v[:, start:end], + scale, + ) + ref_os.append(ref_o_chunk) + ref_rems.append(ref_rem_chunk) + + ref_o = torch.cat(ref_os, dim=1) + ref_rem = torch.cat(ref_rems, dim=1) + + (ref_o * do).sum().backward(retain_graph=True) + (ref_rem * dr).sum().backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + tri_o, tri_rem = parallel_stickbreaking_attn(q, k, v, scale=scale, cu_seqlens=cu_seqlens_tensor) + (tri_o * do).sum().backward(retain_graph=True) + (tri_rem * dr).sum().backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + assert_close("o", ref_o, tri_o, 0.008) + assert_close("rem", ref_rem, tri_rem, 0.02) + assert_close("dq", ref_dq, tri_dq, 0.02) + assert_close("dk", ref_dk, tri_dk, 0.02) + assert_close("dv", ref_dv, tri_dv, 0.02) From 18d3fb2d264a8c5b7a911c8f241c912d6b000fc4 Mon Sep 17 00:00:00 2001 From: Nathancgy4 Date: Mon, 10 Nov 2025 12:32:16 +0000 Subject: [PATCH 08/10] removed unnecessary files --- sba_code/.gitignore | 162 ----- sba_code/LICENSE | 201 ------ sba_code/README.md | 32 - sba_code/benchmarks/attn.py | 96 --- sba_code/benchmarks/varlen.py | 129 ---- sba_code/load_model_with_dolomite_demo.py | 7 - sba_code/setup.py | 29 - sba_code/stickbreaking_attention/__init__.py | 2 - .../sb_attn/__init__.py | 64 -- .../stickbreaking_attention/sb_attn/sb_bwd.py | 297 -------- .../stickbreaking_attention/sb_attn/sb_fwd.py | 253 ------- sba_code/stickbreaking_attention/sb_ref.py | 25 - .../sb_varlen/__init__.py | 82 --- .../sb_varlen/sb_varlen_bwd.py | 641 ------------------ .../sb_varlen/sb_varlen_fwd.py | 522 -------------- .../sb_varlen/softplus.py | 52 -- sba_code/stickbreaking_attention/utils.py | 39 -- sba_code/tests/__init__.py | 0 sba_code/tests/test_attn.py | 74 -- sba_code/tests/test_varlen.py | 110 --- 20 files changed, 2817 deletions(-) delete mode 100644 sba_code/.gitignore delete mode 100644 sba_code/LICENSE delete mode 100644 sba_code/README.md delete mode 100644 sba_code/benchmarks/attn.py delete mode 100644 sba_code/benchmarks/varlen.py delete mode 100644 sba_code/load_model_with_dolomite_demo.py delete mode 100644 sba_code/setup.py delete mode 100644 sba_code/stickbreaking_attention/__init__.py delete mode 100644 sba_code/stickbreaking_attention/sb_attn/__init__.py delete mode 100644 sba_code/stickbreaking_attention/sb_attn/sb_bwd.py delete mode 100644 sba_code/stickbreaking_attention/sb_attn/sb_fwd.py delete mode 100644 sba_code/stickbreaking_attention/sb_ref.py delete mode 100644 sba_code/stickbreaking_attention/sb_varlen/__init__.py delete mode 100644 sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py delete mode 100644 sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py delete mode 100644 sba_code/stickbreaking_attention/sb_varlen/softplus.py delete mode 100644 sba_code/stickbreaking_attention/utils.py delete mode 100644 sba_code/tests/__init__.py delete mode 100644 sba_code/tests/test_attn.py delete mode 100644 sba_code/tests/test_varlen.py diff --git a/sba_code/.gitignore b/sba_code/.gitignore deleted file mode 100644 index 82f927558..000000000 --- a/sba_code/.gitignore +++ /dev/null @@ -1,162 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] 
-*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/latest/usage/project/#working-with-version-control -.pdm.toml -.pdm-python -.pdm-build/ - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/sba_code/LICENSE b/sba_code/LICENSE deleted file mode 100644 index 261eeb9e9..000000000 --- a/sba_code/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. 
- - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
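[Editorial note] For orientation while reading the variable-length pieces of this patch (the README below, tests/test_varlen.py above, and the sb_varlen kernels further down): sequences in a mini-batch are packed along the token dimension and delimited by cu_seqlens, the cumulative sum of the per-sequence lengths. The sketch below is illustrative only; the argument names follow the sb_attn_varlen wrapper in sb_varlen/__init__.py and tests/test_varlen.py, while the concrete shapes and sizes are invented for the example and assume the sba_code package is installed on a CUDA machine.

```python
import math
import torch

from stickbreaking_attention.sb_varlen import sb_attn_varlen

num_heads, head_dim = 8, 64
# Two sequences packed into one batch; `lengths` holds per-sequence token counts.
lengths = torch.tensor([300, 212], device="cuda", dtype=torch.int32)
total_length = int(lengths.sum())

# q, k, v are packed along the token dimension: (num_heads, total_length, head_dim).
q = torch.randn(num_heads, total_length, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence boundaries; with zero_start=False the leading 0 is omitted.
cu_seqlens = torch.cumsum(lengths, dim=0).to(torch.int32)

o, rem = sb_attn_varlen(
    q, k, v,
    cu_seqlens=cu_seqlens,
    max_seqlens=int(lengths.max()),
    inv_temp=1 / math.sqrt(head_dim),
    zero_start=False,
)
# As in the tests: route the leftover stick mass to each token's own value.
o = o + rem[..., None] * v
```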
diff --git a/sba_code/README.md b/sba_code/README.md deleted file mode 100644 index ee44de9d3..000000000 --- a/sba_code/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Stick-breaking Attention Implementation -Triton-based implementation of Stick-breaking Attention on GPUs. -This implementation is for variable length . -You can find the paper [here](https://arxiv.org/abs/2410.17980) - -## Installation -```sh -# Install editable. This will allow you to modify stickbreaking in this directory. -pip install -e . -# Check all is working well. -pytest -x tests -``` -### Usage -#### Variable Length Attention -Each mini-batch consists of concatenated sequences of different lengths. - -`sb_attn_varlen` implements the counterpart to Flash Attention's -[`flash_attn_varlen_func`](https://github.com/Dao-AILab/flash-attention/blob/0dfb28174333d9eefb7c1dd4292690a8458d1e89/flash_attn/flash_attn_interface.py#L1360). -Assuming we have an input batch that concatenates all documents/sequences into a long array, and the corresponding -sequence lengths in the batch in an array `lengths`. -Then we can compute the cu_seqlens and pass that to `sb_attn_varlen`: -```python -import torch -from stickbreaking_attention.sb_varlen import sb_attn_varlen -# lengths: batch_size, -total_length = torch.sum(lengths) -# q, k, v: num_heads, total_length, head_dima -cu_seqlens = torch.cumsum(lengths) -o, rem = sb_attn_varlen(q, k, v, cu_seqlens, zero_start=False) -``` - -Enjoy! diff --git a/sba_code/benchmarks/attn.py b/sba_code/benchmarks/attn.py deleted file mode 100644 index cc9458416..000000000 --- a/sba_code/benchmarks/attn.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import pytest -import math -from torch.nn import functional as F -from stickbreaking_attention.sb_attn import sb_attn -import triton -from flash_attn import flash_attn_func -from flash_attn.flash_attn_triton import flash_attn_func as triton_flash_attn_func -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, rotate_half -from transformers import set_seed - - -def tri_fwdbwd(do, q, k, v): - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - o, rem = sb_attn(q, k, v, inv_temp=1 / math.sqrt(q.size(-1))) - o = o.permute(0, 2, 1, 3) - # o = o + rem[..., None] * v - return o - -def flash_fwdbwd(rope, position_ids, do, q, k, v): - cos, sin = rope(v, position_ids) - cos = cos.unsqueeze(-2) - sin = sin.unsqueeze(-2) - q = (q * cos) + (rotate_half(q) * sin) - k = (k * cos) + (rotate_half(k) * sin) - o = flash_attn_func(q, k, v, causal=True) - # o = o.permute(0, 2, 1, 3) - return o - -def triton_flash_fwdbwd(rope, position_ids, do, q, k, v): - cos, sin = rope(v, position_ids) - cos = cos.unsqueeze(-2) - sin = sin.unsqueeze(-2) - q = (q * cos) + (rotate_half(q) * sin) - k = (k * cos) + (rotate_half(k) * sin) - o = triton_flash_attn_func(q, k, v, None, True) - # o = o.permute(0, 2, 1, 3) - return o - - -providers = [ - ("triton", "Stickbreaking", ("blue", "-")), - ("flash", "Flash Attention", ("green", "-")), - # ("triton_flash", "Triton Flash", ("red", "-")), # triton flash not working -] -@triton.testing.perf_report([ - triton.testing.Benchmark( - x_names=["length"], - x_vals=[4096, 2 * 4096, 3 * 4096, 4 * 4096], - line_arg="provider", - line_vals=[x[0] for x in providers], - line_names=[x[1] for x in providers], - styles=[x[2] for x in providers], - ylabel="ms", - plot_name=f"triton v torch", - args={"batch_size": 4, "num_heads": 12, "head_dim": 128, "dtype": torch.bfloat16, "bwd": True} - ) -]) 
-def benchmark_attn(batch_size, num_heads, head_dim, length, dtype, provider, bwd): - device = torch.device('cuda:0') - set_seed(1337) - warmup = 100 - rep = 1000 - - q = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) - k = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) - v = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) - q.requires_grad_() - k.requires_grad_() - v.requires_grad_() - do = torch.randn((batch_size, length, num_heads, head_dim), device=device, dtype=dtype) - position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :] - if provider == "triton": - fun = lambda: tri_fwdbwd(do, q, k, v) - elif provider == "flash": - rope = LlamaRotaryEmbedding(dim=head_dim).to(device) - fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v) - elif provider == "triton_flash": - rope = LlamaRotaryEmbedding(dim=head_dim).to(device) - fun = lambda: triton_flash_fwdbwd(rope, position_ids, do, q, k, v) - - if bwd: - def fun_(): - o = fun() - dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) - - return triton.testing.do_bench(fun_, warmup=warmup, rep=rep) - else: - return triton.testing.do_bench(fun, warmup=warmup, rep=rep) - - - -if __name__ == "__main__": - benchmark_attn.run(save_path=None, print_data=True) diff --git a/sba_code/benchmarks/varlen.py b/sba_code/benchmarks/varlen.py deleted file mode 100644 index e04c805a4..000000000 --- a/sba_code/benchmarks/varlen.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -import pytest -import math -from torch.nn import functional as F -from stickbreaking_attention.sb_varlen import sb_attn_varlen -import triton -from flash_attn import flash_attn_varlen_func -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, rotate_half -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers import set_seed -from stickbreaking_attention.sb_ref import stickbreaking - - - -def ref_fwd(q, k, v, lengths): - q = q.permute(1, 0, 2) - k = k.permute(1, 0, 2) - v = v.permute(1, 0, 2) - splits = list(lengths.cpu().numpy()) - max_len = max(splits) - cm = torch.ones(max_len, max_len).tril(-1).to(q) - mask = torch.ones(max_len, max_len).triu(0).cuda().bool() - outputs = [] - for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)): - len = q_chunk.size(1) - o, rem = stickbreaking( - q_chunk[None, :], - k_chunk[None, :], - v_chunk[None, :], - mask[:len, :len], cm[:len, :len] - ) - - # o = o + rem[..., None] * v_chunk[None] - outputs.append(o[0]) - return torch.cat(outputs, 1) - -def ref_fwdbwd(do, q, k, v, lengths): - o = ref_fwd(q, k, v, lengths) - return o - - -def tri_fwdbwd(do, q, k, v, lengths): - q = q.permute(1, 0, 2) - k = k.permute(1, 0, 2) - v = v.permute(1, 0, 2) - cu_seqlens = torch.cumsum(lengths, dim=-1) - o, rem = sb_attn_varlen(q, k, v, - cu_seqlens=cu_seqlens, - max_seqlens=max(lengths).item(), - inv_temp=1 / math.sqrt(q.size(-1)), - zero_start=False) - # o = o + rem[..., None] * v - return o - -def flash_fwdbwd(rope, position_ids, do, q, k, v, lengths): - cos, sin = rope(v, position_ids) - q = (q * cos) + (rotate_half(q) * sin) - k = (k * cos) + (rotate_half(k) * sin) - lengths = lengths.to(torch.int32) - cu_seqlens = torch.cumsum(lengths, dim=-1) - cu_seqlens = F.pad(cu_seqlens, (1, 0)).to(torch.int32) - max_len = torch.max(lengths) - o = flash_attn_varlen_func( - q, k, v, - 
cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_len, - max_seqlen_k=max_len, - causal=True - ) - o = o.permute(1, 0, 2) - return o - - -providers = [ - # ("reference", "Stickbreaking (ref.)", ("red", "-")), - ("triton", "Stickbreaking", ("blue", "-")), - ("flash", "Flash Attention", ("green", "-")), -] -@triton.testing.perf_report([ - triton.testing.Benchmark( - x_names=["length"], - x_vals=[4096, 2 * 4096, 3 * 4096, 4 * 4096], - line_arg="provider", - line_vals=[x[0] for x in providers], - line_names=[x[1] for x in providers], - styles=[x[2] for x in providers], - ylabel="ms", - plot_name=f"triton v torch", - args={"batch_size": 4, "num_heads": 12, "head_dim": 128, "dtype": torch.bfloat16, "bwd": True} - ) -]) -def benchmark_varlen(batch_size, num_heads, head_dim, length, dtype, provider, bwd): - device = torch.device('cuda:0') - set_seed(1337) - lengths = torch.randint(length, length + 1, (batch_size,)).to(device=device, dtype=torch.int32) - total_length = lengths.sum() - warmup = 100 - rep = 1000 - - q = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) - k = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) - v = torch.randn((total_length, num_heads, head_dim), device=device, dtype=dtype) - q.requires_grad_() - k.requires_grad_() - v.requires_grad_() - do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype) - position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :] - - if provider== "reference": - fun = lambda: ref_fwdbwd(do, q, k, v, lengths) - elif provider == "triton": - fun = lambda: tri_fwdbwd(do, q, k, v, lengths) - elif provider == "flash": - config = LlamaConfig(max_position_embeddings=length) - rope = LlamaRotaryEmbedding(config).to(device) - fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths) - if bwd: - def fun_(): - o = fun() - dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) - return triton.testing.do_bench(fun_, warmup=warmup, rep=rep) - else: - return triton.testing.do_bench(fun, warmup=warmup, rep=rep) - - - -if __name__ == "__main__": - benchmark_varlen.run(save_path=None, print_data=True) diff --git a/sba_code/load_model_with_dolomite_demo.py b/sba_code/load_model_with_dolomite_demo.py deleted file mode 100644 index 267918dfe..000000000 --- a/sba_code/load_model_with_dolomite_demo.py +++ /dev/null @@ -1,7 +0,0 @@ -import transformers -from dolomite_engine import hf_models - -if __name__ == "__main__": - hf_models = transformers.AutoModelForCausalLM.from_pretrained( - 'shawntan/stickbreaking-3b', - ) diff --git a/sba_code/setup.py b/sba_code/setup.py deleted file mode 100644 index 809fff7c8..000000000 --- a/sba_code/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from setuptools import setup, find_packages - -def read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() - -setup( - name = "stickbreaking_attention", - version = "0.0.0", - author = "Shawn Tan", - author_email = "shawntan@ibm.com", - description = "Triton implementation of Stick-breaking attention", - license = "Apache License", - keywords = "triton pytorch llm stickbreaking attention", - url = "https://github.com/shawntan/scattermoe", - packages=find_packages(), - long_description=read('README.md'), - python_requires='>=3.10.10', - install_requires=[ - 'torch', - 'triton', - ], - tests_require=['pytest', 'numpy'], - classifiers=[ - "Development Status :: 1 - Planning", - "License :: OSI Approved :: Apache 
Software License", - ], -) - diff --git a/sba_code/stickbreaking_attention/__init__.py b/sba_code/stickbreaking_attention/__init__.py deleted file mode 100644 index e798ad2f8..000000000 --- a/sba_code/stickbreaking_attention/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .sb_attn import sb_attn -from .sb_varlen import sb_attn_varlen diff --git a/sba_code/stickbreaking_attention/sb_attn/__init__.py b/sba_code/stickbreaking_attention/sb_attn/__init__.py deleted file mode 100644 index 22d55b583..000000000 --- a/sba_code/stickbreaking_attention/sb_attn/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -import math - -import torch -import triton.language as tl - -from .sb_bwd import _bwd -from .sb_fwd import _fwd - - -FWD_BLOCK_M: tl.constexpr = 64 -FWD_BLOCK_N: tl.constexpr = 32 -BWD_BLOCK_M: tl.constexpr = 64 -BWD_BLOCK_N: tl.constexpr = 32 - - -class StickBreakingAttention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, inv_temp: float, - attend_current: bool = False): - no_grad = not ctx.needs_input_grad[0] - logit_scale = inv_temp - BLOCK_M = FWD_BLOCK_M - BLOCK_N = FWD_BLOCK_N - o, rem, neg_log_acc = _fwd( - q, k, v, logit_scale=inv_temp, no_grad=no_grad, - return_attention=False, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - attend_current=attend_current - ) - ctx.save_for_backward(q, k, v, neg_log_acc) - ctx.logit_scale = logit_scale - ctx.attend_current = attend_current - return o, rem - - @staticmethod - def backward(ctx, do: torch.Tensor, drem: torch.Tensor): - logit_scale = ctx.logit_scale - attend_current = ctx.attend_current - q, k, v, neg_log_acc = ctx.saved_tensors - BLOCK_M = BWD_BLOCK_M - BLOCK_N = BWD_BLOCK_N - dq, dk, dv = _bwd( - do, - drem, - q, - k, - v, - neg_log_acc, - logit_scale, - attend_current=attend_current, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - return dq, dk, dv, None, None - - -def sb_attn(q, k, v, inv_temp=None, zero_start=True, attend_current=False): - if inv_temp is None: - inv_temp = 1 / math.sqrt(q.size(-1)) - return sb_attn_(q, k, v, inv_temp, attend_current=attend_current) - - -def sb_attn_(q, k, v, inv_temp, attend_current): - return StickBreakingAttention.apply(q, k, v, inv_temp, attend_current) diff --git a/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py b/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py deleted file mode 100644 index 8b881c5a0..000000000 --- a/sba_code/stickbreaking_attention/sb_attn/sb_bwd.py +++ /dev/null @@ -1,297 +0,0 @@ -import torch -import triton -import triton.language as tl - -from ..utils import ALLOW_TF32, inv_log2, custom_op -from ..sb_varlen.sb_varlen_bwd import _backward_one_row -from ..sb_varlen.sb_varlen_fwd import compute_block, load_kv - - -def get_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) for s in [8] for w in [4]] - - -@triton.autotune( - configs=get_configs(), - key=["token_size", "head_size"], -) -# reset_to_zero=["DK_ptr", "DV_ptr"]) -@triton.jit() -def _backward( - DO_ptr, - stride_dob, - stride_doh, - stride_dom: tl.constexpr, - stride_dod: tl.constexpr, - DR_ptr, - stride_drb, - stride_drh, - stride_drm: tl.constexpr, - A_ptr, - stride_ab, - stride_ah, - stride_am: tl.constexpr, - Q_ptr, - stride_qb, - stride_qh, - stride_qm: tl.constexpr, - stride_qd: tl.constexpr, - K_ptr, - stride_kb, - stride_kh, - stride_kn: tl.constexpr, - stride_kd: tl.constexpr, - V_ptr, - stride_vb, - stride_vh, - stride_vn: tl.constexpr, - stride_vd: tl.constexpr, - DQ_ptr, - stride_dqb, - stride_dqh, - stride_dqm: tl.constexpr, - stride_dqd: 
tl.constexpr, - DK_ptr, - stride_dkb, - stride_dkh, - stride_dkn: tl.constexpr, - stride_dkd: tl.constexpr, - DV_ptr, - stride_dvb, - stride_dvh, - stride_dvn: tl.constexpr, - stride_dvd: tl.constexpr, - KV_Lock_ptr, - KV_Count_ptr, - stride_kvb: tl.constexpr, - stride_kvl: tl.constexpr, - logit_scale, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, - ALLOW_TF32: tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, - attend_current: tl.constexpr = False, -): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - batch_id = tl.program_id(0) - head_pid = tl.program_id(1) - prog_id = tl.program_id(2) - # Universal stuff - qk_scale = inv_log2 * logit_scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - cm = tl.where(N_range[:, None] >= N_range[None, :], - 1.0, 0.0).to(Q_ptr.type.element_ty) - - head_id = head_pid - seq_prog_id = prog_id - seq_length = token_size - - DO_head_seq_ptr = DO_ptr + stride_dob * batch_id + stride_doh * head_id - DR_head_seq_ptr = DR_ptr + stride_drb * batch_id + stride_drh * head_id - A_head_seq_ptr = A_ptr + stride_ab * batch_id + stride_ah * head_id - Q_head_seq_ptr = Q_ptr + stride_qb * batch_id + stride_qh * head_id - K_head_seq_ptr = K_ptr + stride_kb * batch_id + stride_kh * head_id - V_head_seq_ptr = V_ptr + stride_vb * batch_id + stride_vh * head_id - DQ_head_seq_ptr = DQ_ptr + stride_dqb * batch_id + stride_dqh * head_id - DK_head_seq_ptr = DK_ptr + stride_dkb * batch_id + stride_dkh * head_id - DV_head_seq_ptr = DV_ptr + stride_dvb * batch_id + stride_dvh * head_id - KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvb * batch_id + stride_kvl * head_id - KV_Count_head_seq_ptr = KV_Count_ptr + \ - stride_kvb * batch_id + stride_kvl * head_id - _backward_one_row( - seq_prog_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - DO_head_seq_ptr, - stride_dom, - stride_dod, - DR_head_seq_ptr, - stride_drm, - A_head_seq_ptr, - stride_am, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - DQ_head_seq_ptr, - stride_dqm, - stride_dqd, - DK_head_seq_ptr, - stride_dkn, - stride_dkd, - DV_head_seq_ptr, - stride_dvn, - stride_dvd, - KV_Lock_head_seq_ptr, - KV_Count_head_seq_ptr, - logit_scale, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - acc_dtype, - is_compiling=is_compiling, - attend_current=attend_current, - ) - - -def _bwd(do, dr, q, k, v, neg_log_acc, logit_scale, - attend_current=False, BLOCK_M=64, BLOCK_N=32): - batch_size, num_heads, token_size, dim_size = q.size() - M_count = triton.cdiv(token_size, BLOCK_M) - N_count = triton.cdiv(token_size, BLOCK_N) - - # dqdkdv = torch.zeros((batch_size, token_size, num_heads, 3 * dim_size), device=do.device, dtype=do.dtype) - # dqdkdv = dqdkdv.permute(0, 2, 1, 3) - # dq, dk, dv = dqdkdv.chunk(3, dim=-1) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k, dtype=torch.bfloat16) - dv = torch.zeros_like(v, dtype=torch.bfloat16) - - M_count = triton.cdiv(token_size, BLOCK_M) - N_count = M_count * (BLOCK_M // BLOCK_N) - dkdv_lock = torch.zeros((batch_size, num_heads, N_count), - dtype=torch.int32, device=q.device) - dkdv_count = torch.zeros( - (batch_size, num_heads, 
N_count), dtype=torch.bool, device=q.device) - _compileable_backward( - do, - dr, - q, - k, - v, - neg_log_acc, - logit_scale, - attend_current, - BLOCK_M, - BLOCK_N, - batch_size, - num_heads, - token_size, - dim_size, - M_count, - N_count, - dq, - dk, - dv, - dkdv_lock, - dkdv_count, - ) - return dq, dk, dv - - -@custom_op("attn_bwd", mutates_args={"dq", "dk", "dv", "dkdv_lock", "dkdv_count"}) -def _compileable_backward( - do: torch.Tensor, - dr: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - neg_log_acc: torch.Tensor, - logit_scale: float, - attend_current: bool, - BLOCK_M: int, - BLOCK_N: int, - batch_size: int, - num_heads: int, - token_size: int, - dim_size: int, - M_count: int, - N_count: int, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - dkdv_lock: torch.Tensor, - dkdv_count: torch.Tensor, -) -> None: - BLOCK_D = triton.next_power_of_2(dim_size) - _backward[batch_size, num_heads, M_count]( - do, - do.stride(0), - do.stride(1), - do.stride(2), - do.stride(3), - dr, - dr.stride(0), - dr.stride(1), - dr.stride(2), - neg_log_acc, - neg_log_acc.stride(0), - neg_log_acc.stride(1), - neg_log_acc.stride(2), - q, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k, - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v, - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - dq, - dq.stride(0), - dq.stride(1), - dq.stride(2), - dq.stride(3), - dk, - dk.stride(0), - dk.stride(1), - dk.stride(2), - dk.stride(3), - dv, - dv.stride(0), - dv.stride(1), - dv.stride(2), - dv.stride(3), - dkdv_lock, - dkdv_count, - num_heads * N_count, - N_count, - logit_scale=logit_scale, - attend_current=attend_current, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_D=BLOCK_D, - NO_D_MASK=BLOCK_D == dim_size, - NO_M_MASK=(token_size % BLOCK_M) == 0, - NO_N_MASK=(token_size % BLOCK_N) == 0, - ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, - acc_dtype=tl.float32, - is_compiling=False, - ) diff --git a/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py b/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py deleted file mode 100644 index b9f04cc2d..000000000 --- a/sba_code/stickbreaking_attention/sb_attn/sb_fwd.py +++ /dev/null @@ -1,253 +0,0 @@ -import torch -import triton -import triton.language as tl - -from ..utils import ALLOW_TF32, inv_log2, custom_op -from ..sb_varlen.sb_varlen_fwd import _forward_one_row -from ..sb_varlen.softplus import softplus - - - -def get_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) for s in [4] for w in [4]] - - -@triton.autotune(configs=get_configs(), key=["token_size", "head_size"]) -@triton.jit -def _forward( - Q_ptr, - stride_qb, - stride_qh, - stride_qm: tl.constexpr, - stride_qd: tl.constexpr, - K_ptr, - stride_kb, - stride_kh, - stride_kn: tl.constexpr, - stride_kd: tl.constexpr, - V_ptr, - stride_vb, - stride_vh, - stride_vn: tl.constexpr, - stride_vd: tl.constexpr, - O_ptr, - stride_ob, - stride_oh, - stride_om: tl.constexpr, - stride_od: tl.constexpr, - R_ptr, - stride_rb, - stride_rh, - stride_rm: tl.constexpr, - A_ptr, - stride_ab, - stride_ah, - stride_am: tl.constexpr, - W_ptr, - stride_wb, - stride_wh, - stride_wm, - stride_wn, - logit_scale: tl.constexpr, - attend_current: tl.constexpr, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, - ALLOW_TF32: 
tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - no_grad: tl.constexpr = False, - acc_dtype: tl.constexpr = tl.float32, - return_attention: tl.constexpr = False, - is_compiling: tl.constexpr = False, -): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - batch_id = tl.program_id(0) - head_pid = tl.program_id(1) - prog_id = tl.program_id(2) - tl.num_programs(2) - seq_length = token_size - # Universal stuff - qk_scale = inv_log2 * logit_scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - cm = tl.where(N_range[:, None] >= N_range[None, :], - 1.0, 0.0).to(Q_ptr.type.element_ty) - - # First head block - head_id = head_pid - seq_prog_id = prog_id - # tl.store(pid_debug_ptr + head_id * tl.num_programs(1) + prog_id_start_offset + seq_prog_id, pid) - Q_head_seq_ptr = Q_ptr + stride_qb * batch_id + stride_qh * head_id - K_head_seq_ptr = K_ptr + stride_kb * batch_id + stride_kh * head_id - V_head_seq_ptr = V_ptr + stride_vb * batch_id + stride_vh * head_id - O_head_seq_ptr = O_ptr + stride_ob * batch_id + stride_oh * head_id - R_head_seq_ptr = R_ptr + stride_rb * batch_id + stride_rh * head_id - A_head_seq_ptr = A_ptr + stride_ab * batch_id + stride_ah * head_id - W_head_seq_ptr = W_ptr + stride_wb * batch_id + stride_wh * head_id - _forward_one_row( - seq_prog_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - O_head_seq_ptr, - stride_om, - stride_od, - R_head_seq_ptr, - stride_rm, - A_head_seq_ptr, - stride_am, - W_head_seq_ptr, - stride_wm, - stride_wn, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - NO_N_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - no_grad, - acc_dtype, - return_attention, - attend_current=attend_current, - is_compiling=is_compiling, - ) - - -def _fwd(q, k, v, logit_scale, - attend_current=False, - no_grad=False, return_attention=False, - BLOCK_M: int = 64, BLOCK_N: int = 32): - batch_size, num_heads, token_size, dim_size = q.size() - o = torch.empty_like(q) - rem = torch.zeros_like(q[:, :, :, 0], device=q.device) - neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) - if return_attention: - W = torch.full((batch_size, num_heads, token_size, token_size), - 0.0, dtype=torch.float32, device=q.device) - else: - W = torch.empty((1, 1, 1, 1), device=q.device) - _compileable_fwd( - q, - k, - v, - logit_scale, - no_grad, - return_attention, - BLOCK_M, - BLOCK_N, - batch_size, - num_heads, - token_size, - dim_size, - o, - rem, - neg_log_acc, - W, - attend_current=attend_current, - ) - if return_attention: - return o, rem, neg_log_acc, W - else: - return o, rem, neg_log_acc - - -@custom_op("attn_fwd", mutates_args={"o", "rem", "neg_log_acc", "W"}) -def _compileable_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - logit_scale: float, - no_grad: bool, - return_attention: bool, - BLOCK_M: int, - BLOCK_N: int, - batch_size: int, - num_heads: int, - token_size: int, - dim_size: int, - o: torch.Tensor, - rem: torch.Tensor, - neg_log_acc: torch.Tensor, - W: torch.Tensor, - attend_current: bool, -) -> None: - num_folded_heads = num_heads - num_seq_blocks = triton.cdiv(token_size, BLOCK_M) - BLOCK_D = triton.next_power_of_2(dim_size) - grid = (batch_size, num_folded_heads, num_seq_blocks) - _forward[grid]( - q, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k, - 
k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v, - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o, - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - rem, - rem.stride(0), - rem.stride(1), - rem.stride(2), - neg_log_acc, - neg_log_acc.stride(0), - neg_log_acc.stride(1), - neg_log_acc.stride(2), - W, - W.stride(0), - W.stride(1), - W.stride(2), - W.stride(3), - logit_scale=logit_scale, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, - no_grad=no_grad, - attend_current=attend_current, - BLOCK_D=BLOCK_D, - NO_D_MASK=BLOCK_D == dim_size, - NO_M_MASK=(token_size % BLOCK_M) == 0, - NO_N_MASK=(token_size % BLOCK_N) == 0, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, - return_attention=return_attention, - acc_dtype=tl.float32, - is_compiling=False, - ) diff --git a/sba_code/stickbreaking_attention/sb_ref.py b/sba_code/stickbreaking_attention/sb_ref.py deleted file mode 100644 index 7e6dc12f2..000000000 --- a/sba_code/stickbreaking_attention/sb_ref.py +++ /dev/null @@ -1,25 +0,0 @@ -import math - -import torch -from torch.nn import functional as F - - -# for reference -def stickbreaking(q, k, v, mask, cum_weight): - """ - Stick-breaking attention weights. - """ - logits = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1]) - - original_dtype = logits.dtype - - logits = logits.float() - log_z = F.logsigmoid(logits).masked_fill(mask, -1e5).to(original_dtype) - - log_beta = F.logsigmoid(-logits).masked_fill(mask, 0).to(original_dtype) - - re_cum_log_beta = torch.einsum( - "bhij,jk->bhik", log_beta, cum_weight.to(log_beta)) - log_att = log_z + re_cum_log_beta - att = log_att.exp() - return att @ v, 1 - att.sum(dim=-1) diff --git a/sba_code/stickbreaking_attention/sb_varlen/__init__.py b/sba_code/stickbreaking_attention/sb_varlen/__init__.py deleted file mode 100644 index 9a7096504..000000000 --- a/sba_code/stickbreaking_attention/sb_varlen/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -from .sb_varlen_fwd import varlen_fwd -from .sb_varlen_bwd import varlen_bwd -import math - -import torch -import triton.language as tl -from torch.nn import functional as F - - -FWD_BLOCK_M: tl.constexpr = 64 -FWD_BLOCK_N: tl.constexpr = 32 -BWD_BLOCK_M: tl.constexpr = 64 -BWD_BLOCK_N: tl.constexpr = 32 - - -def calculate_programs_needed(cu_seqlens: torch.Tensor, BLOCK_SIZE): - lens = cu_seqlens.clone() - lens[1:] -= cu_seqlens[:-1] - seq_num_programs = ((lens - 1) // BLOCK_SIZE) + 1 - seq_program_offsets = torch.cumsum(seq_num_programs, dim=0) - return seq_program_offsets - - -class StickBreakingAttention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, cu_seqlens, max_seqlens, inv_temp, attend_current): - no_grad = not ctx.needs_input_grad[0] - logit_scale = inv_temp - o, rem, neg_log_acc = varlen_fwd( - q, - k, - v, - cu_seqlens, - max_seqlens, - logit_scale=inv_temp, - attend_current=attend_current, - no_grad=no_grad, - BLOCK_M=FWD_BLOCK_M, - BLOCK_N=FWD_BLOCK_N, - ) - ctx.save_for_backward(q, k, v, neg_log_acc, cu_seqlens) - ctx.logit_scale = logit_scale - ctx.max_seqlens = max_seqlens - ctx.attend_current = attend_current - return o, rem - - @staticmethod - def backward(ctx, do, drem): - logit_scale = ctx.logit_scale - max_seqlens = ctx.max_seqlens - attend_current = ctx.attend_current - q, k, v, neg_log_acc, cu_seqlens = ctx.saved_tensors - dq, dk, dv = varlen_bwd( - do, - drem, - q, - k, - v, - cu_seqlens, - max_seqlens, - neg_log_acc, - logit_scale, - 
attend_current=attend_current, - BLOCK_M=BWD_BLOCK_M, - BLOCK_N=BWD_BLOCK_N, - ) - return dq, dk, dv, None, None, None, None - - -def sb_attn_varlen(q, k, v, cu_seqlens, max_seqlens, inv_temp=None, zero_start=True, attend_current=False): - if zero_start: - assert cu_seqlens[0] == 0 - cu_seqlens = cu_seqlens[1:] - if inv_temp is None: - inv_temp = 1 / math.sqrt(q.size(-1)) - - return sb_attn_varlen_(q, k, v, inv_temp, cu_seqlens, max_seqlens, attend_current) - - -def sb_attn_varlen_(q, k, v, inv_temp, cu_seqlens, max_seqlens, attend_current): - return StickBreakingAttention.apply(q, k, v, cu_seqlens, max_seqlens, inv_temp, attend_current) diff --git a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py deleted file mode 100644 index 1812b4788..000000000 --- a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py +++ /dev/null @@ -1,641 +0,0 @@ -import math - -import torch -import triton -import triton.language as tl - -from ..utils import ALLOW_TF32, inv_log2 -from .sb_varlen_fwd import compute_block, load_kv - -from ..utils import custom_op - -@triton.jit -def locked_add(Lock_ptr, Count_ptr, A_ptrs, a, B_ptrs, b, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr, - EVICTION_POLICY: tl.constexpr=tl.constexpr("")): - while tl.atomic_cas(Lock_ptr, 0, 1) == 1: - pass - # tl.device_print("Start locked add.") - count = tl.load(Count_ptr, eviction_policy=EVICTION_POLICY) - if NO_D_MASK: - if NO_N_MASK: - if count == 0: - tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) - else: - a += tl.load(A_ptrs, eviction_policy=EVICTION_POLICY) - b += tl.load(B_ptrs, eviction_policy=EVICTION_POLICY) - tl.store(A_ptrs, a, eviction_policy=EVICTION_POLICY) - tl.store(B_ptrs, b, eviction_policy=EVICTION_POLICY) - - else: - if count == 0: - tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) - else: - a += tl.load(A_ptrs, - mask=N_mask[:, None], eviction_policy=EVICTION_POLICY) - b += tl.load(B_ptrs, - mask=N_mask[:, None], eviction_policy=EVICTION_POLICY) - tl.store(A_ptrs, a, mask=N_mask[:, None], - eviction_policy=EVICTION_POLICY) - tl.store(B_ptrs, b, mask=N_mask[:, None], - eviction_policy=EVICTION_POLICY) - - else: - # if True: # TODO delete - mask = N_mask[:, None] & D_mask[None, :] - if count == 0: - tl.store(Count_ptr, 1, eviction_policy=EVICTION_POLICY) - else: - a += tl.load(A_ptrs, mask=mask, eviction_policy=EVICTION_POLICY) - b += tl.load(B_ptrs, mask=mask, eviction_policy=EVICTION_POLICY) - tl.store(A_ptrs, a, mask=mask, eviction_policy=EVICTION_POLICY) - tl.store(B_ptrs, b, mask=mask, eviction_policy=EVICTION_POLICY) - - # tl.device_print("End locked add.") - tl.atomic_xchg(Lock_ptr, 0) - -@triton.jit -def _locked_add(Lock_ptr, Count_ptr, A_ptrs, a, B_ptrs, b, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr, - EVICTION_POLICY: tl.constexpr=""): - # count = tl.load(Count_ptr, eviction_policy=EVICTION_POLICY) - if NO_D_MASK: - if NO_N_MASK: - tl.atomic_add(A_ptrs, a) - tl.atomic_add(B_ptrs, b) - else: - tl.atomic_add(A_ptrs, a, mask=N_mask[:, None]) - tl.atomic_add(B_ptrs, b, mask=N_mask[:, None]) - else: - mask = N_mask[:, None] & D_mask[None, :] - tl.atomic_add(A_ptrs, a, mask=mask) - tl.atomic_add(B_ptrs, b, mask=mask) - - -def get_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) - # for mb in [64, 128] - # for nb in [16, 32, 64] - # for s in [8, 7, 6, 5, 4, 3, 2] - # for w in [4 , 2]] - # for mb in [32] - # for nb in [32] - for s in [8] - for w in [4]] - - - -@triton.autotune( - 
configs=get_configs(), - key=["token_size", "head_size"], - reset_to_zero=["DK_ptr", "DV_ptr", "KV_Lock_ptr", "KV_Count_ptr"] -) -@triton.jit -def _backward( - DO_ptr, - stride_doh: tl.constexpr, - stride_dom, - stride_dod: tl.constexpr, - DR_ptr, - stride_drh, - stride_drm, - A_ptr, - stride_ah, - stride_am, - Q_ptr, - stride_qh: tl.constexpr, - stride_qm, - stride_qd: tl.constexpr, - K_ptr, - stride_kh: tl.constexpr, - stride_kn, - stride_kd: tl.constexpr, - V_ptr, - stride_vh: tl.constexpr, - stride_vn, - stride_vd: tl.constexpr, - DQ_ptr, - stride_dqh: tl.constexpr, - stride_dqm, - stride_dqd: tl.constexpr, - DK_ptr, - stride_dkh: tl.constexpr, - stride_dkn, - stride_dkd: tl.constexpr, - DV_ptr, - stride_dvh: tl.constexpr, - stride_dvn, - stride_dvd: tl.constexpr, - KV_Lock_ptr, - KV_Count_ptr, - stride_kvs, - stride_kvh, - CSL_ptr, - logit_scale, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - BLOCK_CSL: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, - ALLOW_TF32: tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - acc_dtype: tl.constexpr = tl.float32, - attend_current: tl.constexpr = False -): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - seq_id = tl.program_id(0) - fhead_id = tl.program_id(1) - seq_alloc_prog_id = tl.program_id(2) - num_seq_alloc_progs = tl.num_programs(2) - if seq_id == 0: - seq_start_offset = 0 - else: - seq_start_offset = tl.load(CSL_ptr + seq_id - 1).to(tl.int32) - seq_end_offset = tl.load(CSL_ptr + seq_id).to(tl.int32) - seq_length = seq_end_offset - seq_start_offset - num_seq_blocks = tl.cdiv(seq_length, BLOCK_M) - - seq_a_block_id = num_seq_blocks - seq_alloc_prog_id - 1 - seq_b_block_id = seq_alloc_prog_id - (num_seq_alloc_progs - num_seq_blocks) - - if seq_a_block_id >= 0 or seq_b_block_id >= 0: - # Universal stuff - qk_scale = inv_log2 * logit_scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - cm = tl.where(N_range[:, None] >= N_range[None, :], - 1.0, 0.0).to(Q_ptr.type.element_ty) - - if seq_a_block_id >= 0: - head_id = fhead_id * 2 - DO_head_seq_ptr = DO_ptr + stride_doh * head_id + stride_dom * seq_start_offset - DR_head_seq_ptr = DR_ptr + stride_drh * head_id + stride_drm * seq_start_offset - A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset - Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset - K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset - V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset - DQ_head_seq_ptr = DQ_ptr + stride_dqh * head_id + stride_dqm * seq_start_offset - DK_head_seq_ptr = DK_ptr + stride_dkh * head_id + stride_dkn * seq_start_offset - DV_head_seq_ptr = DV_ptr + stride_dvh * head_id + stride_dvn * seq_start_offset - KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvs * seq_id + stride_kvh * head_id - KV_Count_head_seq_ptr = KV_Count_ptr + \ - stride_kvs * seq_id + stride_kvh * head_id - _backward_one_row( - seq_a_block_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - DO_head_seq_ptr, - stride_dom, - stride_dod, - DR_head_seq_ptr, - stride_drm, - A_head_seq_ptr, - stride_am, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - DQ_head_seq_ptr, - stride_dqm, - stride_dqd, - DK_head_seq_ptr, 
- stride_dkn, - stride_dkd, - DV_head_seq_ptr, - stride_dvn, - stride_dvd, - KV_Lock_head_seq_ptr, - KV_Count_head_seq_ptr, - logit_scale, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - acc_dtype, - attend_current=attend_current - ) - if seq_b_block_id >= 0 and fhead_id * 2 + 1 < num_heads: - head_id = fhead_id * 2 + 1 - DO_head_seq_ptr = DO_ptr + stride_doh * head_id + stride_dom * seq_start_offset - DR_head_seq_ptr = DR_ptr + stride_drh * head_id + stride_drm * seq_start_offset - A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset - Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset - K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset - V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset - DQ_head_seq_ptr = DQ_ptr + stride_dqh * head_id + stride_dqm * seq_start_offset - DK_head_seq_ptr = DK_ptr + stride_dkh * head_id + stride_dkn * seq_start_offset - DV_head_seq_ptr = DV_ptr + stride_dvh * head_id + stride_dvn * seq_start_offset - KV_Lock_head_seq_ptr = KV_Lock_ptr + stride_kvs * seq_id + stride_kvh * head_id - KV_Count_head_seq_ptr = KV_Count_ptr + \ - stride_kvs * seq_id + stride_kvh * head_id - _backward_one_row( - seq_b_block_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - DO_head_seq_ptr, - stride_dom, - stride_dod, - DR_head_seq_ptr, - stride_drm, - A_head_seq_ptr, - stride_am, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - DQ_head_seq_ptr, - stride_dqm, - stride_dqd, - DK_head_seq_ptr, - stride_dkn, - stride_dkd, - DV_head_seq_ptr, - stride_dvn, - stride_dvd, - KV_Lock_head_seq_ptr, - KV_Count_head_seq_ptr, - logit_scale, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - acc_dtype, - attend_current=attend_current - ) - - -@triton.jit -def _backward_one_row( - seq_prog_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - DO_head_seq_ptr, - stride_dom, - stride_dod: tl.constexpr, - DR_head_seq_ptr, - stride_drm, - A_head_seq_ptr, - stride_am: tl.constexpr, - Q_head_seq_ptr, - stride_qm, - stride_qd: tl.constexpr, - K_head_seq_ptr, - stride_kn, - stride_kd: tl.constexpr, - V_head_seq_ptr, - stride_vn, - stride_vd: tl.constexpr, - DQ_head_seq_ptr, - stride_dqm, - stride_dqd: tl.constexpr, - DK_head_seq_ptr, - stride_dkn, - stride_dkd: tl.constexpr, - DV_head_seq_ptr, - stride_dvn, - stride_dvd: tl.constexpr, - KV_Lock_ptr, - KV_Count_ptr, - logit_scale, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - ALLOW_TF32: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, - attend_current: tl.constexpr = False, -): - # Loading thread information - block_start_offset = BLOCK_M * seq_prog_id - M_blk_idxs = block_start_offset + M_range - M_mask = M_blk_idxs < seq_length - NO_M_MASK = (block_start_offset + BLOCK_M - 1) < seq_length - - N_blk_idxs_start = 0 - N_blk_idxs = N_blk_idxs_start + N_range - - # Init pointers - # Inputs - DO_blk_ptrs = DO_head_seq_ptr + \ - (stride_dom * M_blk_idxs[:, None] + stride_dod * D_range[None, :]) - - K_blk_ptrs = K_head_seq_ptr + \ - (stride_kn * N_blk_idxs[:, None] + stride_kd * D_range[None, :]) - Q_blk_ptrs = Q_head_seq_ptr + \ - (stride_qm * M_blk_idxs[:, None] + stride_qd * D_range[None, :]) - V_blk_ptrs = V_head_seq_ptr + \ - (stride_vn * 
N_blk_idxs[:, None] + stride_vd * D_range[None, :]) - A_blk_ptrs = A_head_seq_ptr + stride_am * M_blk_idxs - # Outputs - DQ_blk_ptrs = DQ_head_seq_ptr + \ - (stride_dqm * M_blk_idxs[:, None] + stride_dqd * D_range[None, :]) - DK_blk_ptrs = DK_head_seq_ptr + \ - (stride_dkn * N_blk_idxs[:, None] + stride_dkd * D_range[None, :]) - DV_blk_ptrs = DV_head_seq_ptr + \ - (stride_dvn * N_blk_idxs[:, None] + stride_dvd * D_range[None, :]) - DR_blk_ptrs = DR_head_seq_ptr + stride_drm * M_blk_idxs - - # --- Load band vectors --- - if NO_D_MASK: - if NO_M_MASK: - q = tl.load(Q_blk_ptrs) - do = tl.load(DO_blk_ptrs) - dr = tl.load(DR_blk_ptrs) - neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) - else: - q = tl.load(Q_blk_ptrs, mask=M_mask[:, None]) - do = tl.load(DO_blk_ptrs, mask=M_mask[:, None]) - dr = tl.load(DR_blk_ptrs, mask=M_mask) - neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) - else: - MD_mask = M_mask[:, None] & D_mask[None, :] - q = tl.load(Q_blk_ptrs, mask=MD_mask) - do = tl.load(DO_blk_ptrs, mask=MD_mask) - dr = tl.load(DR_blk_ptrs, mask=M_mask) - neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) - # --- End band vectors --- - - # Init accumulators - neg_log_acc = neg_log_acc.to(dtype=acc_dtype) - grad_prev_acc = tl.zeros((BLOCK_M,), dtype=acc_dtype) - dq = tl.zeros((BLOCK_M, BLOCK_D), dtype=acc_dtype) - - fwd_cm = tl.trans(cm) - # always multiple of number of blocks. - iters = (block_start_offset + BLOCK_M) // BLOCK_N - # if (last_N_blk_idxs_end - sequence_start_offset) % BLOCK_N > 0: - # tl.device_print('remainder') - # Iterate only up to start of sequence - for i in range(iters): - on_band = (iters - i - 1) < BLOCK_M // BLOCK_N - N_mask = N_blk_idxs < seq_length - NO_N_MASK = (N_blk_idxs_start + BLOCK_N - 1) < seq_length - # --- Recompute block --- - k, v = load_kv( - K_blk_ptrs, - V_blk_ptrs, - N_mask=N_mask, - NO_N_MASK=(N_blk_idxs_start + BLOCK_N - 1) < seq_length, - # N_mask=N_mask, NO_N_MASK=False, - D_mask=D_mask, - NO_D_MASK=NO_D_MASK, - ) - p, log_om_beta, neg_log_acc = compute_block( - q, - k, - qk_scale, - neg_log_acc, - M_blk_idxs, - N_blk_idxs, - cm, - on_band, - ALLOW_TF32, - attend_current=attend_current, - backward=True, - is_compiling=is_compiling, - ) - - if not NO_M_MASK: - neg_log_acc = tl.where(M_mask, neg_log_acc, 0.0) - - # --- Do gradient stuff --- - att_dA = p * \ - (tl.dot(do, tl.trans(v), allow_tf32=ALLOW_TF32) - dr[:, None]) - cumul_att_dA = ( - tl.dot(att_dA.to(cm.dtype), fwd_cm, - allow_tf32=ALLOW_TF32) + grad_prev_acc[:, None] - ) # 180 -> 174 - # cumul_att_dA = tl.cumsum(att_dA, axis=1) + grad_prev_acc[:, None] # 180 -> 174 - grad_prev_acc += tl.sum(att_dA, axis=1) - beta = 1 - tl.exp2(log_om_beta) # 180 -> 175 - dqk = att_dA - beta * cumul_att_dA - - dq = tl.dot(dqk.to(k.dtype), k, acc=dq, allow_tf32=ALLOW_TF32) - block_dk = tl.dot(tl.trans(dqk).to(q.dtype), q, - allow_tf32=ALLOW_TF32) * logit_scale - block_dv = tl.dot(tl.trans(p), do.to(p.dtype), allow_tf32=ALLOW_TF32) - - locked_add( - KV_Lock_ptr + i, - KV_Count_ptr + i, - DK_blk_ptrs, - block_dk, - DV_blk_ptrs, - block_dv, - N_mask, - NO_N_MASK, - D_mask, - NO_D_MASK, - ) - - # --- End gradient stuff --- - N_blk_idxs += BLOCK_N - N_blk_idxs_start += BLOCK_N - K_blk_ptrs += BLOCK_N * stride_kn - V_blk_ptrs += BLOCK_N * stride_vn - DK_blk_ptrs += BLOCK_N * stride_dkn - DV_blk_ptrs += BLOCK_N * stride_dvn - - dq = (logit_scale * dq).to(DQ_head_seq_ptr.type.element_ty) - - if NO_D_MASK: - tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None]) - else: - tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None] & D_mask[None, :]) 
- - -def varlen_bwd( - do: torch.Tensor, - dr: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlens: int, - neg_log_acc: torch.Tensor, - logit_scale, - attend_current=False, - BLOCK_M=64, - BLOCK_N=32, -): - batch_size = cu_seqlens.size(0) - num_heads, token_size, dim_size = q.size() - if logit_scale is None: - logit_scale = 1 / math.sqrt(dim_size) - N_count = triton.cdiv(token_size, BLOCK_N) - - # dqdkdv = torch.zeros((token_size, num_heads, 3 * dim_size), device=do.device, dtype=do.dtype) - # dqdkdv = dqdkdv.permute(1, 0, 2) - # dq, dk, dv = dqdkdv.chunk(3, dim=-1) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - num_sequences = batch_size - num_folded_heads = triton.cdiv(num_heads, 2) - num_seq_blocks = triton.cdiv(max_seqlens, BLOCK_M) + 1 - _compileable_backward( - do, - dr, - q, - k, - v, - cu_seqlens, - neg_log_acc, - logit_scale, - BLOCK_M, - BLOCK_N, - batch_size, - num_heads, - token_size, - dim_size, - dq, - dk, - dv, - # dkdv_lock, - # dkdv_count, - num_sequences, - num_folded_heads, - num_seq_blocks, - attend_current=attend_current - ) - return dq, dk, dv - - -@custom_op("varlen_bwd", mutates_args={"dq", "dk", "dv"}) -def _compileable_backward( - do: torch.Tensor, - dr: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens: torch.Tensor, - neg_log_acc: torch.Tensor, - logit_scale: float, - BLOCK_M: int, - BLOCK_N: int, - batch_size: int, - num_heads: int, - token_size: int, - dim_size: int, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - # dkdv_lock: torch.Tensor, - # dkdv_count: torch.Tensor, - num_sequences: int, - num_folded_heads: int, - num_seq_blocks: int, - attend_current: bool = False, -) -> None: - BLOCK_D = triton.next_power_of_2(dim_size) - N_count = num_seq_blocks * (BLOCK_M // BLOCK_N) - dkdv_lock = torch.zeros( - (num_sequences, num_heads, N_count), dtype=torch.int32, device=q.device) - dkdv_count = torch.zeros( - (num_sequences, num_heads, N_count), dtype=torch.int32, device=q.device) - - _backward[num_sequences, num_folded_heads, num_seq_blocks]( - # DO_ptr, stride_doh, stride_dom, stride_dod, - do, - do.stride(0), - do.stride(1), - do.stride(2), - # DR_ptr, stride_drh, stride_drm, - dr, - dr.stride(0), - dr.stride(1), - # A_ptr, stride_ah, stride_am, - neg_log_acc, - neg_log_acc.stride(0), - neg_log_acc.stride(1), - # Q_ptr, stride_qh, stride_qm, stride_qd, - q, - q.stride(0), - q.stride(1), - q.stride(2), - # K_ptr, stride_kh, stride_kn, stride_kd, - k, - k.stride(0), - k.stride(1), - k.stride(2), - # V_ptr, stride_vh, stride_vn, stride_vd, - v, - v.stride(0), - v.stride(1), - v.stride(2), - # DQ_ptr, stride_dqh, stride_dqm, stride_dqd, - dq, - dq.stride(0), - dq.stride(1), - dq.stride(2), - # DK_ptr, stride_dkh, stride_dkn, stride_dkd, - dk, - dk.stride(0), - dk.stride(1), - dk.stride(2), - # DV_ptr, stride_dvh, stride_dvn, stride_dvd, - dv, - dv.stride(0), - dv.stride(1), - dv.stride(2), - # KV_Lock_ptr, KV_Count_ptr, stride_kvl, - dkdv_lock, - dkdv_count, - dkdv_lock.stride(0), - dkdv_lock.stride(1), - cu_seqlens, - logit_scale=logit_scale, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - BLOCK_D=BLOCK_D, - BLOCK_CSL=triton.next_power_of_2(batch_size), - NO_D_MASK=BLOCK_D == dim_size, - NO_M_MASK=False, - NO_N_MASK=False, - ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, - attend_current=attend_current, - BLOCK_M=BLOCK_M, 
BLOCK_N=BLOCK_N - ) diff --git a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py b/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py deleted file mode 100644 index eceb7b78b..000000000 --- a/sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py +++ /dev/null @@ -1,522 +0,0 @@ -import torch -import triton -import triton.language as tl - -from ..utils import ALLOW_TF32, inv_log2 -from .softplus import softplus -from ..utils import custom_op - - -@triton.jit -def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): - if NO_D_MASK: - if NO_N_MASK: - k = tl.load(K_blk_ptrs) - v = tl.load(V_blk_ptrs) - else: - k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) - v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) - else: - mask = N_mask[:, None] & D_mask[None, :] - k = tl.load(K_blk_ptrs, mask=mask) - v = tl.load(V_blk_ptrs, mask=mask) - return k, v - - -@triton.jit -def compute_block( - q, - k, - qk_scale, - neg_log_acc, - M_blk_idxs, - N_blk_idxs, - cm, - on_band: tl.constexpr, - ALLOW_TF32: tl.constexpr, - backward: tl.constexpr, - attend_current: tl.constexpr = False, - use_cumsum: tl.constexpr = False, - is_compiling: tl.constexpr = False, -): - qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale - - # log_om_beta (one minus beta) : log(1 - \beta) - log_om_beta = -softplus(qk, is_compiling=is_compiling) - - if on_band: # diagonal - if attend_current: - block_mask = M_blk_idxs[:, None] >= N_blk_idxs[None, :] - else: - block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] - log_om_beta = tl.where(block_mask, log_om_beta, 0.0) - if backward: - neg_log_acc -= tl.sum(log_om_beta, axis=1) - log_p = qk + neg_log_acc[:, None] - - if use_cumsum: - log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) - else: - log_p = tl.dot(log_om_beta.to(q.dtype), cm, - acc=log_p, allow_tf32=ALLOW_TF32) - - p = tl.math.exp2(log_p) - p = tl.where(block_mask, p, 0.0) - else: - if backward: - neg_log_acc -= tl.sum(log_om_beta, axis=1) - log_p = qk + neg_log_acc[:, None] - if use_cumsum: - log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) - else: - log_p = tl.dot(log_om_beta.to(q.dtype), cm, - acc=log_p, allow_tf32=ALLOW_TF32) - - p = tl.math.exp2(log_p) - if not backward: - neg_log_acc += tl.sum(log_om_beta, axis=1) - return p, log_om_beta, neg_log_acc - - -@triton.jit -def _forward_one_row( - seq_block_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - Q_head_seq_ptr, - stride_qm, - stride_qd: tl.constexpr, - K_head_seq_ptr, - stride_kn, - stride_kd: tl.constexpr, - V_head_seq_ptr, - stride_vn, - stride_vd: tl.constexpr, - O_head_seq_ptr, - stride_om, - stride_od: tl.constexpr, - R_head_seq_ptr, - stride_rm, - A_head_seq_ptr, - stride_am, - W_head_seq_ptr, - stride_wm, - stride_wn, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, - ALLOW_TF32: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - no_grad: tl.constexpr = False, - acc_dtype: tl.constexpr = tl.float32, - return_attention: tl.constexpr = False, - is_compiling: tl.constexpr = False, - use_cumsum: tl.constexpr = False, - attend_current: tl.constexpr = False, -): - # Loading thread information - block_start_offset = BLOCK_M * seq_block_id - M_blk_idxs = block_start_offset + M_range - M_mask = M_blk_idxs < seq_length - NO_M_MASK = (block_start_offset + BLOCK_M - 1) < seq_length - - # BLOCK_M must be a multiple of BLOCK_N - N_blk_idxs_start = block_start_offset + BLOCK_M - 
N_blk_idxs = N_blk_idxs_start + N_range - - # Init pointers - Q_blk_ptrs = Q_head_seq_ptr + \ - (stride_qm * M_blk_idxs[:, None] + stride_qd * D_range[None, :]) - K_blk_ptrs = K_head_seq_ptr + \ - (stride_kn * N_blk_idxs[:, None] + stride_kd * D_range[None, :]) - V_blk_ptrs = V_head_seq_ptr + \ - (stride_vn * N_blk_idxs[:, None] + stride_vd * D_range[None, :]) - O_blk_ptrs = O_head_seq_ptr + \ - (stride_om * M_blk_idxs[:, None] + stride_od * D_range[None, :]) - R_blk_ptrs = R_head_seq_ptr + stride_rm * M_blk_idxs - A_blk_ptrs = A_head_seq_ptr + stride_am * M_blk_idxs - - # --- Load band vectors --- - if NO_D_MASK: - if NO_M_MASK: - q = tl.load(Q_blk_ptrs) - else: - q = tl.load(Q_blk_ptrs, mask=M_mask[:, None], other=0.0) - else: - q = tl.load( - Q_blk_ptrs, mask=M_mask[:, None] & D_mask[None, :], other=0.0) - - iters = N_blk_idxs_start // BLOCK_N - neg_log_acc = tl.zeros([BLOCK_M], dtype=acc_dtype) - acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=acc_dtype) - # --- End band vectors --- - - # Iterate only up to start of sequence - for i in range(iters): - N_blk_idxs -= BLOCK_N - N_blk_idxs_start -= BLOCK_N - K_blk_ptrs -= BLOCK_N * stride_kn - V_blk_ptrs -= BLOCK_N * stride_vn - - N_mask = N_blk_idxs < seq_length - k, v = load_kv( - K_blk_ptrs, - V_blk_ptrs, - N_mask=N_mask, - NO_N_MASK=N_blk_idxs_start + BLOCK_N - 1 < seq_length, - D_mask=D_mask, - NO_D_MASK=NO_D_MASK, - ) - on_band = i < BLOCK_M // BLOCK_N - p, _, neg_log_acc = compute_block( - q, - k, - qk_scale, - neg_log_acc, - M_blk_idxs, - N_blk_idxs, - cm, - on_band, - ALLOW_TF32, - attend_current=attend_current, - backward=False, - is_compiling=is_compiling, - use_cumsum=use_cumsum, - ) - # Store intermediate values - acc = tl.dot(p.to(v.dtype), v, acc, allow_tf32=ALLOW_TF32) - if return_attention: # TODO write returns_attention_weight - tl.store( - W_head_seq_ptr + stride_wm * - M_blk_idxs[:, None] + stride_wn * N_blk_idxs[None, :], - p, - mask=(M_blk_idxs < seq_length)[:, None] & ( - N_blk_idxs < seq_length)[None, :], - ) - if NO_M_MASK: - tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc)) - tl.store(A_blk_ptrs, neg_log_acc.to(A_head_seq_ptr.type.element_ty)) - else: - tl.store(R_blk_ptrs, tl.math.exp2(neg_log_acc), mask=M_mask) - tl.store(A_blk_ptrs, neg_log_acc.to( - A_head_seq_ptr.type.element_ty), mask=M_mask) - if NO_D_MASK: - tl.store(O_blk_ptrs, acc.to( - O_head_seq_ptr.type.element_ty), mask=M_mask[:, None]) - else: - tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), - mask=M_mask[:, None] & D_mask[None, :]) - - -def get_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) - # for mb in [64, 128] - # for nb in [16, 32, 64] - for s in [4] # , 2, 3, 5, 6, 7, 8] - for w in [4]] # , 2]] - # for mb in [64] - # for nb in [32] - # for s in [4] - # for w in [4]] - - - -@triton.autotune(configs=get_configs(), key=["head_size"]) -@triton.jit -def _forward( - Q_ptr, - stride_qh: tl.constexpr, - stride_qm, - stride_qd: tl.constexpr, - K_ptr, - stride_kh: tl.constexpr, - stride_kn, - stride_kd: tl.constexpr, - V_ptr, - stride_vh: tl.constexpr, - stride_vn, - stride_vd: tl.constexpr, - O_ptr, - stride_oh: tl.constexpr, - stride_om, - stride_od: tl.constexpr, - R_ptr, - stride_rh, - stride_rm: tl.constexpr, - A_ptr, - stride_ah, - stride_am: tl.constexpr, - W_ptr, - stride_wh, - stride_wm, - stride_wn, - CSL_ptr, - logit_scale: tl.constexpr, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: 
tl.constexpr, - ALLOW_TF32: tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - no_grad: tl.constexpr = False, - acc_dtype: tl.constexpr = tl.float32, - return_attention: tl.constexpr = False, - use_cumsum: tl.constexpr = False, - attend_current: tl.constexpr = False -): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - seq_id = tl.program_id(0) - fhead_id = tl.program_id(1) - seq_alloc_prog_id = tl.program_id(2) - num_seq_alloc_progs = tl.num_programs(2) - if seq_id == 0: - seq_start_offset = 0 - else: - seq_start_offset = tl.load(CSL_ptr + seq_id - 1).to(tl.int32) - seq_end_offset = tl.load(CSL_ptr + seq_id).to(tl.int32) - seq_length = seq_end_offset - seq_start_offset - num_seq_blocks = tl.cdiv(seq_length, BLOCK_M) - - seq_a_block_id = num_seq_blocks - seq_alloc_prog_id - 1 - seq_b_block_id = seq_alloc_prog_id - (num_seq_alloc_progs - num_seq_blocks) - - if seq_a_block_id >= 0 or seq_b_block_id >= 0: - # Universal stuff - qk_scale = inv_log2 * logit_scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - if not use_cumsum: - cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to( - Q_ptr.type.element_ty) - else: - cm = None - - if seq_a_block_id >= 0: - # First head block - head_id = fhead_id * 2 - Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset - K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset - V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset - O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset - R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset - A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset - W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset - _forward_one_row( - seq_a_block_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - O_head_seq_ptr, - stride_om, - stride_od, - R_head_seq_ptr, - stride_rm, - A_head_seq_ptr, - stride_am, - W_head_seq_ptr, - stride_wm, - stride_wn, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - NO_N_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - no_grad, - acc_dtype, - return_attention, - use_cumsum=use_cumsum, - attend_current=attend_current - ) - if seq_b_block_id >= 0 and fhead_id * 2 + 1 < num_heads: - # Reverse head block - head_id = fhead_id * 2 + 1 - Q_head_seq_ptr = Q_ptr + stride_qh * head_id + stride_qm * seq_start_offset - K_head_seq_ptr = K_ptr + stride_kh * head_id + stride_kn * seq_start_offset - V_head_seq_ptr = V_ptr + stride_vh * head_id + stride_vn * seq_start_offset - O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset - R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset - A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset - W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset - _forward_one_row( - seq_b_block_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - Q_head_seq_ptr, - stride_qm, - stride_qd, - K_head_seq_ptr, - stride_kn, - stride_kd, - V_head_seq_ptr, - stride_vn, - stride_vd, - O_head_seq_ptr, - stride_om, - stride_od, - R_head_seq_ptr, - stride_rm, - A_head_seq_ptr, - stride_am, - W_head_seq_ptr, - stride_wm, - stride_wn, - BLOCK_D, - 
NO_D_MASK, - NO_M_MASK, - NO_N_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - no_grad, - acc_dtype, - return_attention, - use_cumsum=use_cumsum, - attend_current=attend_current - ) - - -def varlen_fwd( - q, k, v, cu_seqlens, max_seqlens, logit_scale, attend_current=False, no_grad=False, return_attention=False, BLOCK_M=64, BLOCK_N=32 -): - batch_size = cu_seqlens.size(0) - num_heads, token_size, dim_size = q.size() - o = torch.empty_like(q) - rem = torch.zeros_like(q[:, :, 0], device=q.device) - neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) - if return_attention: - W = torch.full((num_heads, token_size, token_size), 0.0, - dtype=torch.float32, device=q.device) - else: - W = torch.empty((1, 1, 1), device=q.device) - - _compileable_forward( - q, - k, - v, - cu_seqlens, - max_seqlens, - logit_scale, - no_grad, - return_attention, - BLOCK_M, - BLOCK_N, - num_heads, - batch_size, - token_size, - dim_size, - o, - rem, - neg_log_acc, - W, - attend_current=attend_current - ) - if return_attention: - return o, rem, neg_log_acc, W - else: - return o, rem, neg_log_acc - - -@custom_op("varlen_fwd", mutates_args={"o", "rem", "neg_log_acc", "W"}) -def _compileable_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlens: int, - logit_scale: float, - no_grad: bool, - return_attention: bool, - BLOCK_M: int, - BLOCK_N: int, - num_heads: int, - batch_size: int, - token_size: int, - dim_size: int, - o: torch.Tensor, - rem: torch.Tensor, - neg_log_acc: torch.Tensor, - W: torch.Tensor, - attend_current: bool, -) -> None: - num_sequences = batch_size - num_folded_heads = triton.cdiv(num_heads, 2) - num_seq_blocks = triton.cdiv(max_seqlens, BLOCK_M) + 1 - BLOCK_D = triton.next_power_of_2(dim_size) - grid = (num_sequences, num_folded_heads, num_seq_blocks) - q_stride = q.stride() - k_stride = k.stride() - v_stride = v.stride() - o_stride = o.stride() - - _forward[grid]( - q, q_stride[0], q_stride[1], q_stride[2], - k, k_stride[0], k_stride[1], k_stride[2], - v, v_stride[0], v_stride[1], v_stride[2], - o, o_stride[0], o_stride[1], o_stride[2], - rem, - rem.stride(0), - rem.stride(1), - neg_log_acc, - neg_log_acc.stride(0), - neg_log_acc.stride(1), - W, - W.stride(0), - W.stride(1), - W.stride(2), - cu_seqlens, - # pid_debug, - logit_scale=logit_scale, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, - no_grad=no_grad, - BLOCK_D=BLOCK_D, - NO_D_MASK=BLOCK_D == dim_size, - NO_M_MASK=False, - NO_N_MASK=False, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, - return_attention=return_attention, - acc_dtype=tl.float32, - use_cumsum=False, - attend_current=attend_current, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) diff --git a/sba_code/stickbreaking_attention/sb_varlen/softplus.py b/sba_code/stickbreaking_attention/sb_varlen/softplus.py deleted file mode 100644 index 50e60a11a..000000000 --- a/sba_code/stickbreaking_attention/sb_varlen/softplus.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import triton -from triton import language as tl - - -def _generate_asm(num_pack): - template = """ - .reg .pred p; - setp.gt.f32 p, ${in_reg}, 15.; - @p mov.f32 ${out_reg}, ${in_reg}; - @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg}; - @!p add.f32 ${out_reg}, ${out_reg}, 1.0; - @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; - """ - out_str = "" - - for i in range(num_pack): - inner_str = template.format(out_reg=i, in_reg=i + num_pack) - out_str += "{" + inner_str + "}\n" - # 
flatten out because torch.compile doesn't like newlines - out_str = " ".join(out_str.split("\n")) - return out_str - - -def _generate_constraints(num_pack): - return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack)) - - -_NUM_REG = 1 -asm_str: tl.constexpr = tl.constexpr(_generate_asm(_NUM_REG)) -constraints_str: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG)) -NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG) - - -@triton.jit -def softplus(x, is_compiling: tl.constexpr = False): - if is_compiling: - tl.static_print("Using triton softplus.") - out = tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) - return out - else: - out = tl.inline_asm_elementwise( - asm=asm_str, - constraints=constraints_str, - pack=NUM_REG, - args=[ - x, - ], - dtype=tl.float32, - is_pure=True, - ) - return out diff --git a/sba_code/stickbreaking_attention/utils.py b/sba_code/stickbreaking_attention/utils.py deleted file mode 100644 index 4810b2542..000000000 --- a/sba_code/stickbreaking_attention/utils.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -from typing import Callable, Iterable, Sequence -import math - -PACKAGE_NAME = "stickbreaking_attention" -log2 = math.log(2) -inv_log2 = 1 / log2 -ALLOW_TF32 = True - - - -def _dispatch(func: Callable, compileable_fn: Callable, *args, **kwargs): - if torch.compiler.is_compiling(): - output = compileable_fn(*args, **kwargs) - else: - output = func(*args, **kwargs) - return output - - -def custom_op( - name: str = None, - mutates_args: str | Iterable[str] = None, - device_types: str | Sequence[str] | None = None, - schema: str | None = None, -) -> Callable: - compileable_name = f"{PACKAGE_NAME}::{name}" - - def _inner(func: Callable): - compileable_func = torch.library.custom_op( - compileable_name, func, mutates_args=mutates_args, device_types=device_types, schema=schema - ) - - def _run(*args, **kwargs): - return _dispatch(func, compileable_func, *args, **kwargs) - # _run.__signature__ = inspect.signature(func) - # _run.__name__ = func.__name__ - return _run - - return _inner diff --git a/sba_code/tests/__init__.py b/sba_code/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sba_code/tests/test_attn.py b/sba_code/tests/test_attn.py deleted file mode 100644 index 41b448c8a..000000000 --- a/sba_code/tests/test_attn.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import pytest -import math -from stickbreaking_attention.sb_attn import sb_attn -from transformers import set_seed -from stickbreaking_attention.sb_ref import stickbreaking -from .test_varlen import assert_close - - -def ref_fwd(q, k, v, length, attend_current=False): - cm = torch.ones(length, length).tril(-1).to(q) - if attend_current: - mask = torch.ones(length, length).triu(1).cuda().bool() - else: - mask = torch.ones(length, length).triu(0).cuda().bool() - o, rem = stickbreaking(q, k, v, mask, cm) - o = o + rem[..., None] * v - return o - -def ref_fwdbwd(do, q, k, v, length, attend_current=False): - q.requires_grad = True - k.requires_grad = True - v.requires_grad = True - output = ref_fwd(q, k, v, length, attend_current) - output.backward(do) - dq = q.grad - dk = k.grad - dv = v.grad - q.grad = None - k.grad = None - v.grad = None - return output, dq, dk, dv - - -class TestClass: - - @pytest.mark.parametrize('batch_size', [4, 2, 1]) - @pytest.mark.parametrize('num_heads', [24, 8, 4, 2, 1, 7]) - @pytest.mark.parametrize('head_dim', [64, 32, 16, 50]) - @pytest.mark.parametrize('length', [4096, 2048, 1024, 512, 256, 500]) - 
@pytest.mark.parametrize('dtype', [torch.bfloat16]) - @pytest.mark.parametrize('forward_only', [False]) - @pytest.mark.parametrize('attend_current', [False, True]) - def test_varlen(self, batch_size, num_heads, head_dim, attend_current, length, dtype, forward_only): - set_seed(1337) - torch.set_printoptions(linewidth=110, edgeitems=30) - device = torch.device('cuda:0') - input_dims = (batch_size, num_heads, length, head_dim) - v = 0.25 * torch.randn(input_dims, device=device, dtype=torch.float32) - q = 0.25 * (torch.randn(input_dims, device=device, dtype=torch.float32) + 1) - k = 0.25 * (torch.randn(input_dims, device=device, dtype=torch.float32) - 1) - print(q.max(), k.max(), v.max()) - q = q.to(dtype).requires_grad_() - k = k.to(dtype).requires_grad_() - v = v.to(dtype).requires_grad_() - do = torch.randn(input_dims, device=device, dtype=dtype) - - with torch.cuda.device(device): - o, rem= sb_attn( - q, k, v, - inv_temp=1 / math.sqrt(q.size(-1)), - attend_current=attend_current - ) - o = o + rem[..., None] * v - ref_out, ref_dq, ref_dk, ref_dv = ref_fwdbwd(do, q, k, v, length, - attend_current=attend_current) - eps = 0.05 - torch.cuda.synchronize() - assert_close("o", ref_out, o, eps) - if not forward_only: - dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) - assert_close("dq", ref_dq, dq, eps) - assert_close("dk", ref_dk, dk, eps) - assert_close("dv", ref_dv, dv, eps) diff --git a/sba_code/tests/test_varlen.py b/sba_code/tests/test_varlen.py deleted file mode 100644 index 12b5073e6..000000000 --- a/sba_code/tests/test_varlen.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -import pytest -import math -from torch.nn import functional as F -from stickbreaking_attention.sb_varlen import sb_attn_varlen -from transformers import set_seed -from stickbreaking_attention.sb_ref import stickbreaking - - -def ref_fwd(q, k, v, lengths, attend_current=False): - splits = list(lengths.cpu().numpy()) - max_len = max(splits) - cm = torch.ones(max_len, max_len).tril(-1).to(q) - mask = torch.ones(max_len, max_len).triu(0 if not attend_current else 1).cuda().bool() - outputs = [] - for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)): - len = q_chunk.size(1) - o, rem = stickbreaking( - q_chunk[None, :], - k_chunk[None, :], - v_chunk[None, :], - mask[:len, :len], cm[:len, :len] - ) - - o = o + rem[..., None] * v_chunk[None] - outputs.append(o[0]) - return torch.cat(outputs, 1) - -def ref_bwd(do, q, k, v, lengths, attend_current=False): - q.requires_grad = True - k.requires_grad = True - v.requires_grad = True - output = ref_fwd(q, k, v, lengths, attend_current=attend_current) - output.backward(do) - dq = q.grad - dk = k.grad - dv = v.grad - q.grad = None - k.grad = None - v.grad = None - return output, dq, dk, dv - -def assert_close(varname, a, b, eps): - if torch.isnan(a).any(): - print("Reference is nan") - return - assert not torch.isnan(b).any() - diff = (a - b).abs() - - max_diff= diff.max() - if max_diff < eps: - print(varname, max_diff.item()) - else: - print(varname, max_diff.item(), diff.median().item()) - print((diff.sum(0).median(dim=0)[0] > eps).int()) - err_locs = (diff.sum(0).median(dim=1)[0] > eps).int() - print(err_locs, err_locs.sum()) - assert max_diff < eps, max_diff - - - -class TestClass: - - # @pytest.mark.parametrize('batch_size', [4, 2, 1]) - # @pytest.mark.parametrize('num_heads', [24, 8, 4, 2, 1, 7]) - # @pytest.mark.parametrize('head_dim', [64, 32, 16, 50]) - # @pytest.mark.parametrize('length', [4096, 2048, 1024, 
512, 256, 500]) - @pytest.mark.parametrize('batch_size', [1]) - @pytest.mark.parametrize('num_heads', [12, 3]) - @pytest.mark.parametrize('head_dim', [128]) - @pytest.mark.parametrize('length', [4096, 8192, 8192 * 2]) - @pytest.mark.parametrize('dtype', [torch.bfloat16]) - @pytest.mark.parametrize('forward_only', [False]) - @pytest.mark.parametrize('attend_current', [False, True]) - def test_varlen(self, batch_size, num_heads, head_dim, length, attend_current, dtype, forward_only): - set_seed(1337) - torch.set_printoptions(linewidth=110, edgeitems=30) - device = torch.device('cuda:0') - lengths = torch.randint(length, length + 1, (batch_size,)).to(device=device, dtype=torch.int32) - print(lengths) - total_length = lengths.sum() - cu_seqlens = torch.cumsum(lengths, dim=-1) - v = 0.25 * torch.randn((num_heads, total_length, head_dim), device=device, dtype=torch.float32) - q = 0.25 * (torch.randn((num_heads, total_length, head_dim), device=device, dtype=torch.float32) + 1) - k = 0.25 * (torch.randn((num_heads, total_length, head_dim), device=device, dtype=torch.float32) - 1) - print(q.max(), k.max(), v.max()) - q = q.to(dtype) - k = k.to(dtype) - v = v.to(dtype) - q.requires_grad_() - k.requires_grad_() - v.requires_grad_() - do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype) - with torch.cuda.device(device): - o, rem = sb_attn_varlen(q, k, v, - cu_seqlens=cu_seqlens, - max_seqlens=torch.max(lengths).item(), - inv_temp=1 / math.sqrt(q.size(-1)), - zero_start=False, - attend_current=attend_current) - o = o + rem[..., None] * v - ref_out, ref_dq, ref_dk, ref_dv = ref_bwd(do, q, k, v, lengths, attend_current=attend_current) - eps = 0.05 - torch.cuda.synchronize() - assert_close("o", ref_out, o, eps) - if not forward_only: - dq, dk, dv = torch.autograd.grad(o, inputs=(q, k, v), grad_outputs=do) - assert_close("dq", ref_dq, dq, eps) - assert_close("dk", ref_dk, dk, eps) - assert_close("dv", ref_dv, dv, eps) From 1bf8ba2693d0be11d360bce6b02994a18533e486 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 14:59:45 +0000 Subject: [PATCH 09/10] Minor fix --- fla/ops/stickbreaking_attn/parallel.py | 726 +++++++++++++------------ 1 file changed, 369 insertions(+), 357 deletions(-) diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py index 0642aaa47..3daa87ac5 100644 --- a/fla/ops/stickbreaking_attn/parallel.py +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -1,111 +1,78 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -import math - import torch import triton import triton.language as tl from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous ALLOW_TF32 = True -inv_log2 = 1.0 / math.log(2.0) -def _get_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) for s in [4] for w in [4]] +@triton.jit +def softplus(x, is_compiling: tl.constexpr = False): + return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) -@triton.autotune(configs=_get_configs(), key=["token_size", "head_size"]) @triton.jit -def stickbreaking_attn_fwd_kernel( - Q_ptr, - K_ptr, - V_ptr, - O_ptr, - R_ptr, - A_ptr, - CU_ptr, - CI_ptr, - scale: tl.constexpr, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, +def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): + if NO_D_MASK: + if NO_N_MASK: + k 
= tl.load(K_blk_ptrs) + v = tl.load(V_blk_ptrs) + else: + k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) + v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) + else: + mask = N_mask[:, None] & D_mask[None, :] + k = tl.load(K_blk_ptrs, mask=mask) + v = tl.load(V_blk_ptrs, mask=mask) + return k, v + + +@triton.jit +def compute_block( + q, + k, + qk_scale, + neg_log_acc, + M_blk_idxs, + N_blk_idxs, + cm, + on_band: tl.constexpr, ALLOW_TF32: tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - no_grad: tl.constexpr = False, - acc_dtype: tl.constexpr = tl.float32, + backward: tl.constexpr, + use_cumsum: tl.constexpr = False, is_compiling: tl.constexpr = False, - IS_VARLEN: tl.constexpr = False, ): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - batch_id = 0 if IS_VARLEN else tl.program_id(0) - head_pid = tl.program_id(1) - prog_id = tl.program_id(2) - tl.num_programs(2) - if IS_VARLEN: - i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) - seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) - bos = tl.load(CU_ptr + i_n).to(tl.int32) - eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) - seq_length = eos - bos - else: - bos = tl.full([], 0, dtype=tl.int32) - seq_block_id = prog_id - seq_length = token_size - qk_scale = inv_log2 * scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) - - head_id = head_pid - seq_prog_id = seq_block_id - batch_offset = batch_id * token_size - Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - K_head_seq_ptr = K_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - V_head_seq_ptr = V_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - O_head_seq_ptr = O_ptr + ((batch_offset + bos) * num_heads + head_id) * head_size - R_head_seq_ptr = R_ptr + ((batch_offset + bos) * num_heads + head_id) - A_head_seq_ptr = A_ptr + ((batch_offset + bos) * num_heads + head_id) + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale + log_om_beta = -softplus(qk, is_compiling=is_compiling) - stickbreaking_attn_fwd_one_row_kernel( - seq_prog_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - Q_head_seq_ptr, - K_head_seq_ptr, - V_head_seq_ptr, - O_head_seq_ptr, - R_head_seq_ptr, - A_head_seq_ptr, - head_size, - num_heads, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - NO_N_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - no_grad, - acc_dtype, - False, - is_compiling=is_compiling, - ) + if on_band: + block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] + log_om_beta = tl.where(block_mask, log_om_beta, 0.0) + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + p = tl.where(block_mask, p, 0.0) + else: + if backward: + neg_log_acc -= tl.sum(log_om_beta, axis=1) + log_p = qk + neg_log_acc[:, None] + if use_cumsum: + log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) + else: + log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) + p = tl.math.exp2(log_p) + if not backward: + neg_log_acc += tl.sum(log_om_beta, axis=1) + return p, log_om_beta, neg_log_acc @triton.jit @@ -125,39 +92,39 @@ def stickbreaking_attn_fwd_one_row_kernel( 
R_head_seq_ptr, A_head_seq_ptr, head_size: tl.constexpr, - num_heads: tl.constexpr, + H: tl.constexpr, BLOCK_D: tl.constexpr, NO_D_MASK: tl.constexpr, NO_M_MASK: tl.constexpr, NO_N_MASK: tl.constexpr, ALLOW_TF32: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, no_grad: tl.constexpr = False, acc_dtype: tl.constexpr = tl.float32, return_attention: tl.constexpr = False, is_compiling: tl.constexpr = False, ): - block_start_offset = BLOCK_M * seq_block_id + block_start_offset = BT * seq_block_id M_blk_idxs = block_start_offset + M_range M_mask = M_blk_idxs < seq_length - N_blk_idxs_start = block_start_offset + BLOCK_M + N_blk_idxs_start = block_start_offset + BT N_blk_idxs = N_blk_idxs_start + N_range Q_blk_ptrs = Q_head_seq_ptr + ( - (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] ) K_blk_ptrs = K_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) V_blk_ptrs = V_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) O_blk_ptrs = O_head_seq_ptr + ( - (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] ) - R_blk_ptrs = R_head_seq_ptr + num_heads * M_blk_idxs - A_blk_ptrs = A_head_seq_ptr + num_heads * M_blk_idxs + R_blk_ptrs = R_head_seq_ptr + H * M_blk_idxs + A_blk_ptrs = A_head_seq_ptr + H * M_blk_idxs if NO_D_MASK: if NO_M_MASK: @@ -167,26 +134,26 @@ def stickbreaking_attn_fwd_one_row_kernel( else: q = tl.load(Q_blk_ptrs, mask=M_mask[:, None] & D_mask[None, :], other=0.0) - iters = N_blk_idxs_start // BLOCK_N - neg_log_acc = tl.zeros([BLOCK_M], dtype=acc_dtype) - acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=acc_dtype) + iters = N_blk_idxs_start // BS + neg_log_acc = tl.zeros([BT], dtype=acc_dtype) + acc = tl.zeros([BT, BLOCK_D], dtype=acc_dtype) for i in range(iters): - N_blk_idxs -= BLOCK_N - N_blk_idxs_start -= BLOCK_N - K_blk_ptrs -= BLOCK_N * (num_heads * head_size) - V_blk_ptrs -= BLOCK_N * (num_heads * head_size) + N_blk_idxs -= BS + N_blk_idxs_start -= BS + K_blk_ptrs -= BS * (H * head_size) + V_blk_ptrs -= BS * (H * head_size) N_mask = N_blk_idxs < seq_length k, v = load_kv( K_blk_ptrs, V_blk_ptrs, N_mask=N_mask, - NO_N_MASK=N_blk_idxs_start + BLOCK_N - 1 < seq_length, + NO_N_MASK=N_blk_idxs_start + BS - 1 < seq_length, D_mask=D_mask, NO_D_MASK=NO_D_MASK, ) - on_band = i < BLOCK_M // BLOCK_N + on_band = i < BT // BS p, _log_om_beta, neg_log_acc = compute_block( q, k, @@ -215,183 +182,6 @@ def stickbreaking_attn_fwd_one_row_kernel( tl.store(O_blk_ptrs, acc.to(O_head_seq_ptr.type.element_ty), mask=M_mask[:, None] & D_mask[None, :]) -def _get_bwd_configs(): - return [triton.Config({}, num_stages=s, num_warps=w) for s in [8] for w in [4]] - - -@triton.autotune(configs=_get_bwd_configs(), key=["token_size", "head_size"]) -@triton.jit() -def stickbreaking_attn_bwd_kernel( - DO_ptr, - DR_ptr, - A_ptr, - Q_ptr, - K_ptr, - V_ptr, - DQ_ptr, - DK_ptr, - DV_ptr, - CU_ptr, - CI_ptr, - scale, - batch_size, - token_size, - head_size: tl.constexpr, - num_heads: tl.constexpr, - BLOCK_D: tl.constexpr, - NO_D_MASK: tl.constexpr, - NO_M_MASK: tl.constexpr, - NO_N_MASK: tl.constexpr, - ALLOW_TF32: tl.constexpr, - inv_log2: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - 
acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, - IS_VARLEN: tl.constexpr = False, -): - tl.static_assert(BLOCK_M % BLOCK_N == 0) - batch_id = 0 if IS_VARLEN else tl.program_id(0) - head_pid = tl.program_id(1) - prog_id = tl.program_id(2) - qk_scale = inv_log2 * scale - M_range = tl.arange(0, BLOCK_M) - N_range = tl.arange(0, BLOCK_N) - D_range = tl.arange(0, BLOCK_D) - D_mask = D_range < head_size - cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) - - if IS_VARLEN: - i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) - seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) - bos = tl.load(CU_ptr + i_n).to(tl.int32) - eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) - seq_length = eos - bos - else: - bos = 0 - seq_block_id = prog_id - seq_length = token_size - - head_id = head_pid - seq_prog_id = seq_block_id - - batch_id_i64 = batch_id.to(tl.int64) - head_id_i64 = head_id.to(tl.int64) - seq_prog_id_i64 = seq_prog_id.to(tl.int64) - bos_i64 = bos.to(tl.int64) - - batch_offset = batch_id_i64 * token_size - head_offset = (batch_offset + bos_i64) * num_heads + head_id_i64 - block_offset = seq_prog_id_i64 * batch_size * token_size * num_heads - - DO_head_seq_ptr = DO_ptr + head_offset * head_size - DR_head_seq_ptr = DR_ptr + head_offset - A_head_seq_ptr = A_ptr + head_offset - Q_head_seq_ptr = Q_ptr + head_offset * head_size - K_head_seq_ptr = K_ptr + head_offset * head_size - V_head_seq_ptr = V_ptr + head_offset * head_size - DQ_head_seq_ptr = DQ_ptr + head_offset * head_size - DK_head_seq_ptr = DK_ptr + (block_offset + head_offset) * head_size - DV_head_seq_ptr = DV_ptr + (block_offset + head_offset) * head_size - - stickbreaking_attn_bwd_one_row_kernel( - seq_prog_id, - seq_length, - qk_scale, - M_range, - N_range, - D_range, - D_mask, - cm, - DO_head_seq_ptr, - DR_head_seq_ptr, - A_head_seq_ptr, - Q_head_seq_ptr, - K_head_seq_ptr, - V_head_seq_ptr, - DQ_head_seq_ptr, - DK_head_seq_ptr, - DV_head_seq_ptr, - scale, - head_size, - num_heads, - BLOCK_D, - NO_D_MASK, - NO_M_MASK, - NO_N_MASK, - ALLOW_TF32, - BLOCK_M, - BLOCK_N, - acc_dtype, - is_compiling=is_compiling, - ) - - -@triton.jit -def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): - if NO_D_MASK: - if NO_N_MASK: - k = tl.load(K_blk_ptrs) - v = tl.load(V_blk_ptrs) - else: - k = tl.load(K_blk_ptrs, mask=N_mask[:, None]) - v = tl.load(V_blk_ptrs, mask=N_mask[:, None]) - else: - mask = N_mask[:, None] & D_mask[None, :] - k = tl.load(K_blk_ptrs, mask=mask) - v = tl.load(V_blk_ptrs, mask=mask) - return k, v - - -@triton.jit -def compute_block( - q, - k, - qk_scale, - neg_log_acc, - M_blk_idxs, - N_blk_idxs, - cm, - on_band: tl.constexpr, - ALLOW_TF32: tl.constexpr, - backward: tl.constexpr, - use_cumsum: tl.constexpr = False, - is_compiling: tl.constexpr = False, -): - qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale - log_om_beta = -softplus(qk, is_compiling=is_compiling) - - if on_band: - block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] - log_om_beta = tl.where(block_mask, log_om_beta, 0.0) - if backward: - neg_log_acc -= tl.sum(log_om_beta, axis=1) - log_p = qk + neg_log_acc[:, None] - if use_cumsum: - log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) - else: - log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) - p = tl.math.exp2(log_p) - p = tl.where(block_mask, p, 0.0) - else: - if backward: - neg_log_acc -= tl.sum(log_om_beta, axis=1) - log_p = qk + neg_log_acc[:, None] 
- if use_cumsum: - log_p += tl.cumsum(log_om_beta.to(q.dtype), axis=1, reverse=True) - else: - log_p = tl.dot(log_om_beta.to(q.dtype), cm, acc=log_p, allow_tf32=ALLOW_TF32) - p = tl.math.exp2(log_p) - if not backward: - neg_log_acc += tl.sum(log_om_beta, axis=1) - return p, log_om_beta, neg_log_acc - - -@triton.jit -def softplus(x, is_compiling: tl.constexpr = False): - return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) - - @triton.jit def stickbreaking_attn_bwd_one_row_kernel( seq_prog_id, @@ -413,18 +203,18 @@ def stickbreaking_attn_bwd_one_row_kernel( DV_head_seq_ptr, scale, head_size: tl.constexpr, - num_heads: tl.constexpr, + H: tl.constexpr, BLOCK_D: tl.constexpr, NO_D_MASK: tl.constexpr, NO_M_MASK: tl.constexpr, NO_N_MASK: tl.constexpr, ALLOW_TF32: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, acc_dtype: tl.constexpr = tl.float32, is_compiling: tl.constexpr = False, ): - block_start_offset = BLOCK_M * seq_prog_id + block_start_offset = BT * seq_prog_id M_blk_idxs = block_start_offset + M_range M_mask = M_blk_idxs < seq_length @@ -432,28 +222,28 @@ def stickbreaking_attn_bwd_one_row_kernel( N_blk_idxs = N_blk_idxs_start + N_range DO_blk_ptrs = DO_head_seq_ptr + ( - (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] ) K_blk_ptrs = K_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) Q_blk_ptrs = Q_head_seq_ptr + ( - (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] ) V_blk_ptrs = V_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) - A_blk_ptrs = A_head_seq_ptr + num_heads * M_blk_idxs + A_blk_ptrs = A_head_seq_ptr + H * M_blk_idxs DQ_blk_ptrs = DQ_head_seq_ptr + ( - (num_heads * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * M_blk_idxs[:, None] + 1 * D_range[None, :] ) DK_blk_ptrs = DK_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) DV_blk_ptrs = DV_head_seq_ptr + ( - (num_heads * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] + (H * head_size) * N_blk_idxs[:, None] + 1 * D_range[None, :] ) - DR_blk_ptrs = DR_head_seq_ptr + num_heads * M_blk_idxs + DR_blk_ptrs = DR_head_seq_ptr + H * M_blk_idxs if NO_D_MASK: if NO_N_MASK: @@ -474,15 +264,15 @@ def stickbreaking_attn_bwd_one_row_kernel( neg_log_acc = tl.load(A_blk_ptrs, mask=M_mask) neg_log_acc = neg_log_acc.to(dtype=acc_dtype) - grad_prev_acc = tl.zeros((BLOCK_M,), dtype=acc_dtype) - dq = tl.zeros((BLOCK_M, BLOCK_D), dtype=acc_dtype) + grad_prev_acc = tl.zeros((BT,), dtype=acc_dtype) + dq = tl.zeros((BT, BLOCK_D), dtype=acc_dtype) fwd_cm = tl.trans(cm) - iters = (block_start_offset + BLOCK_M) // BLOCK_N + iters = (block_start_offset + BT) // BS for i in range(iters): - on_band = (iters - i - 1) < BLOCK_M // BLOCK_N + on_band = (iters - i - 1) < BT // BS N_mask = N_blk_idxs < seq_length - local_no_n_mask = (N_blk_idxs_start + BLOCK_N - 1) < seq_length + local_no_n_mask = (N_blk_idxs_start + BS - 1) < seq_length k, v = load_kv( K_blk_ptrs, V_blk_ptrs, @@ -526,12 +316,12 @@ def stickbreaking_attn_bwd_one_row_kernel( tl.store(DK_blk_ptrs, block_dk, mask=mask) 
tl.store(DV_blk_ptrs, block_dv, mask=mask) - N_blk_idxs += BLOCK_N - N_blk_idxs_start += BLOCK_N - K_blk_ptrs += BLOCK_N * (num_heads * head_size) - V_blk_ptrs += BLOCK_N * (num_heads * head_size) - DK_blk_ptrs += BLOCK_N * (num_heads * head_size) - DV_blk_ptrs += BLOCK_N * (num_heads * head_size) + N_blk_idxs += BS + N_blk_idxs_start += BS + K_blk_ptrs += BS * (H * head_size) + V_blk_ptrs += BS * (H * head_size) + DK_blk_ptrs += BS * (H * head_size) + DV_blk_ptrs += BS * (H * head_size) dq = (scale * dq).to(DQ_head_seq_ptr.type.element_ty) @@ -541,7 +331,223 @@ def stickbreaking_attn_bwd_one_row_kernel( tl.store(DQ_blk_ptrs, dq, mask=M_mask[:, None] & D_mask[None, :]) -def stickbreaking_attn_fwd( +@triton.autotune( + configs=[ + triton.Config({}, num_stages=s, num_warps=w) + for s in [4] + for w in [4] + ], + key=["T", "head_size"] +) +@triton.jit +def parallel_stickbreaking_attn_fwd_kernel( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + R_ptr, + A_ptr, + CU_ptr, + CI_ptr, + scale: tl.constexpr, + B, + T, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + no_grad: tl.constexpr = False, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BT % BS == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + tl.num_programs(2) + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = tl.full([], 0, dtype=tl.int32) + seq_block_id = prog_id + seq_length = T + RCP_LN2: tl.constexpr = 1.4426950216 + + qk_scale = RCP_LN2 * scale + M_range = tl.arange(0, BT) + N_range = tl.arange(0, BS) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + head_id = head_pid + seq_prog_id = seq_block_id + batch_offset = batch_id * T + Q_head_seq_ptr = Q_ptr + ((batch_offset + bos) * H + head_id) * head_size + K_head_seq_ptr = K_ptr + ((batch_offset + bos) * H + head_id) * head_size + V_head_seq_ptr = V_ptr + ((batch_offset + bos) * H + head_id) * head_size + O_head_seq_ptr = O_ptr + ((batch_offset + bos) * H + head_id) * head_size + R_head_seq_ptr = R_ptr + ((batch_offset + bos) * H + head_id) + A_head_seq_ptr = A_ptr + ((batch_offset + bos) * H + head_id) + + stickbreaking_attn_fwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + O_head_seq_ptr, + R_head_seq_ptr, + A_head_seq_ptr, + head_size, + H, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BT, + BS, + no_grad, + acc_dtype, + False, + is_compiling=is_compiling, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_stages=s, num_warps=w) + for s in [8] + for w in [4] + ], + key=["T", "head_size"] +) +@triton.jit() +def parallel_stickbreaking_attn_bwd_kernel( + DO_ptr, + DR_ptr, + A_ptr, + Q_ptr, + K_ptr, + V_ptr, + DQ_ptr, + DK_ptr, + DV_ptr, + CU_ptr, + CI_ptr, + scale, + B, + T, + head_size: tl.constexpr, + H: tl.constexpr, + BLOCK_D: tl.constexpr, + NO_D_MASK: tl.constexpr, + NO_M_MASK: tl.constexpr, + NO_N_MASK: 
tl.constexpr, + ALLOW_TF32: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + acc_dtype: tl.constexpr = tl.float32, + is_compiling: tl.constexpr = False, + IS_VARLEN: tl.constexpr = False, +): + tl.static_assert(BT % BS == 0) + batch_id = 0 if IS_VARLEN else tl.program_id(0) + head_pid = tl.program_id(1) + prog_id = tl.program_id(2) + RCP_LN2: tl.constexpr = 1.4426950216 + + qk_scale = RCP_LN2 * scale + M_range = tl.arange(0, BT) + N_range = tl.arange(0, BS) + D_range = tl.arange(0, BLOCK_D) + D_mask = D_range < head_size + cm = tl.where(N_range[:, None] >= N_range[None, :], 1.0, 0.0).to(Q_ptr.type.element_ty) + + if IS_VARLEN: + i_n = tl.load(CI_ptr + prog_id * 2).to(tl.int32) + seq_block_id = tl.load(CI_ptr + prog_id * 2 + 1).to(tl.int32) + bos = tl.load(CU_ptr + i_n).to(tl.int32) + eos = tl.load(CU_ptr + i_n + 1).to(tl.int32) + seq_length = eos - bos + else: + bos = 0 + seq_block_id = prog_id + seq_length = T + + head_id = head_pid + seq_prog_id = seq_block_id + + batch_id_i64 = batch_id.to(tl.int64) + head_id_i64 = head_id.to(tl.int64) + seq_prog_id_i64 = seq_prog_id.to(tl.int64) + bos_i64 = bos.to(tl.int64) + + batch_offset = batch_id_i64 * T + head_offset = (batch_offset + bos_i64) * H + head_id_i64 + block_offset = seq_prog_id_i64 * B * T * H + + DO_head_seq_ptr = DO_ptr + head_offset * head_size + DR_head_seq_ptr = DR_ptr + head_offset + A_head_seq_ptr = A_ptr + head_offset + Q_head_seq_ptr = Q_ptr + head_offset * head_size + K_head_seq_ptr = K_ptr + head_offset * head_size + V_head_seq_ptr = V_ptr + head_offset * head_size + DQ_head_seq_ptr = DQ_ptr + head_offset * head_size + DK_head_seq_ptr = DK_ptr + (block_offset + head_offset) * head_size + DV_head_seq_ptr = DV_ptr + (block_offset + head_offset) * head_size + + stickbreaking_attn_bwd_one_row_kernel( + seq_prog_id, + seq_length, + qk_scale, + M_range, + N_range, + D_range, + D_mask, + cm, + DO_head_seq_ptr, + DR_head_seq_ptr, + A_head_seq_ptr, + Q_head_seq_ptr, + K_head_seq_ptr, + V_head_seq_ptr, + DQ_head_seq_ptr, + DK_head_seq_ptr, + DV_head_seq_ptr, + scale, + head_size, + H, + BLOCK_D, + NO_D_MASK, + NO_M_MASK, + NO_N_MASK, + ALLOW_TF32, + BT, + BS, + acc_dtype, + is_compiling=is_compiling, + ) + + +def parallel_stickbreaking_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -554,30 +560,30 @@ def stickbreaking_attn_fwd( q, k, v: [B, T, H, D] Returns: o [B, T, H, D], rem [B, T, H], neg_log_acc [B, T, H] """ - batch_size, token_size, num_heads, dim_size = q.size() + B, T, H, D = q.size() o = torch.empty_like(q) rem = torch.zeros_like(q[:, :, :, 0], device=q.device) neg_log_acc = torch.zeros_like(rem, device=q.device, dtype=torch.float32) - BLOCK_M = 64 - BLOCK_N = 64 + BT = 64 + BS = 64 if cu_seqlens is None: - num_seq_blocks = triton.cdiv(token_size, BLOCK_M) - grid = (batch_size, num_heads, num_seq_blocks) + NT = triton.cdiv(T, BT) + grid = (B, H, NT) CI = None else: - CI = prepare_chunk_indices(cu_seqlens, BLOCK_M) - num_seq_blocks = int(CI.shape[0]) - grid = (1, num_heads, num_seq_blocks) - BLOCK_D = triton.next_power_of_2(dim_size) + CI = prepare_chunk_indices(cu_seqlens, BT) + NT = int(CI.shape[0]) + grid = (1, H, NT) + BLOCK_D = triton.next_power_of_2(D) - NO_M_MASK = (token_size % BLOCK_M) == 0 - NO_N_MASK = (token_size % BLOCK_N) == 0 + NO_M_MASK = (T % BT) == 0 + NO_N_MASK = (T % BS) == 0 if cu_seqlens is not None: NO_M_MASK = False NO_N_MASK = False - stickbreaking_attn_fwd_kernel[grid]( + parallel_stickbreaking_attn_fwd_kernel[grid]( q, k, v, @@ -587,18 +593,17 @@ def stickbreaking_attn_fwd( 
CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, scale=scale, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, + B=B, + T=T, + head_size=D, + H=H, BLOCK_D=BLOCK_D, - NO_D_MASK=dim_size == BLOCK_D, + NO_D_MASK=D == BLOCK_D, NO_M_MASK=NO_M_MASK, NO_N_MASK=NO_N_MASK, ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + BT=BT, + BS=BS, no_grad=False, is_compiling=False, IS_VARLEN=cu_seqlens is not None, @@ -607,7 +612,7 @@ def stickbreaking_attn_fwd( return o, rem, neg_log_acc -def stickbreaking_attn_bwd( +def parallel_stickbreaking_attn_bwd( do: torch.Tensor, dr: torch.Tensor, q: torch.Tensor, @@ -617,30 +622,30 @@ def stickbreaking_attn_bwd( scale: float, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size, token_size, num_heads, dim_size = q.size() - BLOCK_M = 64 - BLOCK_N = 64 + B, T, H, D = q.size() + BT = 64 + BS = 64 if cu_seqlens is None: - M_count = triton.cdiv(token_size, BLOCK_M) - grid = (batch_size, num_heads, M_count) + M_count = triton.cdiv(T, BT) + grid = (B, H, M_count) CI = None else: - CI = prepare_chunk_indices(cu_seqlens, BLOCK_M) + CI = prepare_chunk_indices(cu_seqlens, BT) M_count = int(CI.shape[0]) - grid = (1, num_heads, M_count) + grid = (1, H, M_count) dq = torch.zeros_like(q) - dk = torch.zeros((M_count, batch_size, token_size, num_heads, dim_size), dtype=k.dtype, device=k.device) - dv = torch.zeros((M_count, batch_size, token_size, num_heads, dim_size), dtype=v.dtype, device=v.device) + dk = torch.zeros((M_count, B, T, H, D), dtype=k.dtype, device=k.device) + dv = torch.zeros((M_count, B, T, H, D), dtype=v.dtype, device=v.device) - BLOCK_D = triton.next_power_of_2(dim_size) + BLOCK_D = triton.next_power_of_2(D) - NO_M_MASK = (token_size % BLOCK_M) == 0 - NO_N_MASK = (token_size % BLOCK_N) == 0 + NO_M_MASK = (T % BT) == 0 + NO_N_MASK = (T % BS) == 0 if cu_seqlens is not None: NO_M_MASK = False NO_N_MASK = False - stickbreaking_attn_bwd_kernel[grid]( + parallel_stickbreaking_attn_bwd_kernel[grid]( do, dr, neg_log_acc, @@ -653,18 +658,17 @@ def stickbreaking_attn_bwd( CU_ptr=cu_seqlens if cu_seqlens is not None else q, CI_ptr=CI if CI is not None else q, scale=scale, - batch_size=batch_size, - token_size=token_size, - head_size=dim_size, - num_heads=num_heads, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + B=B, + T=T, + head_size=D, + H=H, + BT=BT, + BS=BS, BLOCK_D=BLOCK_D, - NO_D_MASK=dim_size == BLOCK_D, + NO_D_MASK=D == BLOCK_D, NO_M_MASK=NO_M_MASK, NO_N_MASK=NO_N_MASK, ALLOW_TF32=ALLOW_TF32, - inv_log2=inv_log2, acc_dtype=tl.float32, is_compiling=False, IS_VARLEN=cu_seqlens is not None, @@ -677,7 +681,11 @@ def stickbreaking_attn_bwd( class StickBreakingAttentionFunction(torch.autograd.Function): + + @staticmethod @staticmethod + @contiguous + @autocast_custom_fwd def forward( ctx, q: torch.Tensor, @@ -686,16 +694,18 @@ def forward( scale: float, cu_seqlens: torch.LongTensor | None = None, ): - o, rem, neg_log_acc = stickbreaking_attn_fwd(q, k, v, scale, cu_seqlens) + o, rem, neg_log_acc = parallel_stickbreaking_attn_fwd(q, k, v, scale, cu_seqlens) ctx.save_for_backward(q, k, v, neg_log_acc) ctx.scale = scale ctx.cu_seqlens = cu_seqlens return o, rem @staticmethod + @contiguous + @autocast_custom_bwd def backward(ctx, do: torch.Tensor, drem: torch.Tensor): q, k, v, neg_log_acc = ctx.saved_tensors - dq, dk, dv = stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.cu_seqlens) + dq, 
dk, dv = parallel_stickbreaking_attn_bwd(do, drem, q, k, v, neg_log_acc, ctx.scale, ctx.cu_seqlens) return dq, dk, dv, None, None @@ -707,7 +717,9 @@ def parallel_stickbreaking_attn( cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if scale is None: - scale = q.shape[-1] ** -0.5 + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" return StickBreakingAttentionFunction.apply(q, k, v, scale, cu_seqlens) From 3d037ff12882ce3d44ed66792930372c2a788839 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Nov 2025 15:41:01 +0000 Subject: [PATCH 10/10] Add PTX softplus --- fla/ops/stickbreaking_attn/parallel.py | 19 ++-------- fla/ops/stickbreaking_attn/softplus.py | 49 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 17 deletions(-) create mode 100644 fla/ops/stickbreaking_attn/softplus.py diff --git a/fla/ops/stickbreaking_attn/parallel.py b/fla/ops/stickbreaking_attn/parallel.py index 3daa87ac5..8ac9c7bd0 100644 --- a/fla/ops/stickbreaking_attn/parallel.py +++ b/fla/ops/stickbreaking_attn/parallel.py @@ -4,17 +4,13 @@ import triton import triton.language as tl +from fla.ops.stickbreaking_attn.softplus import softplus from fla.ops.utils.index import prepare_chunk_indices from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous ALLOW_TF32 = True -@triton.jit -def softplus(x, is_compiling: tl.constexpr = False): - return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) - - @triton.jit def load_kv(K_blk_ptrs, V_blk_ptrs, N_mask, NO_N_MASK, D_mask, NO_D_MASK: tl.constexpr): if NO_D_MASK: @@ -44,10 +40,9 @@ def compute_block( ALLOW_TF32: tl.constexpr, backward: tl.constexpr, use_cumsum: tl.constexpr = False, - is_compiling: tl.constexpr = False, ): qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * qk_scale - log_om_beta = -softplus(qk, is_compiling=is_compiling) + log_om_beta = -softplus(qk) if on_band: block_mask = M_blk_idxs[:, None] > N_blk_idxs[None, :] @@ -103,7 +98,6 @@ def stickbreaking_attn_fwd_one_row_kernel( no_grad: tl.constexpr = False, acc_dtype: tl.constexpr = tl.float32, return_attention: tl.constexpr = False, - is_compiling: tl.constexpr = False, ): block_start_offset = BT * seq_block_id M_blk_idxs = block_start_offset + M_range @@ -165,7 +159,6 @@ def stickbreaking_attn_fwd_one_row_kernel( on_band, ALLOW_TF32, backward=False, - is_compiling=is_compiling, use_cumsum=False, ) acc = tl.dot(p.to(v.dtype), v, acc, allow_tf32=ALLOW_TF32) @@ -212,7 +205,6 @@ def stickbreaking_attn_bwd_one_row_kernel( BT: tl.constexpr, BS: tl.constexpr, acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, ): block_start_offset = BT * seq_prog_id M_blk_idxs = block_start_offset + M_range @@ -292,7 +284,6 @@ def stickbreaking_attn_bwd_one_row_kernel( on_band, ALLOW_TF32, backward=True, - is_compiling=is_compiling, ) if not NO_M_MASK: @@ -363,7 +354,6 @@ def parallel_stickbreaking_attn_fwd_kernel( BS: tl.constexpr, no_grad: tl.constexpr = False, acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, IS_VARLEN: tl.constexpr = False, ): tl.static_assert(BT % BS == 0) @@ -427,7 +417,6 @@ def parallel_stickbreaking_attn_fwd_kernel( no_grad, acc_dtype, False, - is_compiling=is_compiling, ) @@ -465,7 +454,6 @@ def parallel_stickbreaking_attn_bwd_kernel( BT: tl.constexpr, BS: tl.constexpr, acc_dtype: tl.constexpr = tl.float32, - is_compiling: tl.constexpr = False, IS_VARLEN: tl.constexpr = False, ): 
tl.static_assert(BT % BS == 0) @@ -543,7 +531,6 @@ def parallel_stickbreaking_attn_bwd_kernel( BT, BS, acc_dtype, - is_compiling=is_compiling, ) @@ -605,7 +592,6 @@ def parallel_stickbreaking_attn_fwd( BT=BT, BS=BS, no_grad=False, - is_compiling=False, IS_VARLEN=cu_seqlens is not None, ) @@ -670,7 +656,6 @@ def parallel_stickbreaking_attn_bwd( NO_N_MASK=NO_N_MASK, ALLOW_TF32=ALLOW_TF32, acc_dtype=tl.float32, - is_compiling=False, IS_VARLEN=cu_seqlens is not None, ) diff --git a/fla/ops/stickbreaking_attn/softplus.py b/fla/ops/stickbreaking_attn/softplus.py new file mode 100644 index 000000000..20d306718 --- /dev/null +++ b/fla/ops/stickbreaking_attn/softplus.py @@ -0,0 +1,49 @@ +# COPIED FROM +# https://github.com/shawntan/stickbreaking-attention/blob/main/stickbreaking_attention/sb_varlen/softplus.py + +import triton +from triton import language as tl + + +def _generate_asm(num_pack): + template = """ + .reg .pred p; + setp.gt.f32 p, ${in_reg}, 15.; + @p mov.f32 ${out_reg}, ${in_reg}; + @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg}; + @!p add.f32 ${out_reg}, ${out_reg}, 1.0; + @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; + """ + out_str = "" + + for i in range(num_pack): + inner_str = template.format(out_reg=i, in_reg=i + num_pack) + out_str += "{" + inner_str + "}\n" + # flatten out because torch.compile doesn't like newlines + out_str = " ".join(out_str.split("\n")) + return out_str + + +def _generate_constraints(num_pack): + return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack)) + + +_NUM_REG = 1 +asm_str: tl.constexpr = tl.constexpr(_generate_asm(_NUM_REG)) +constraints_str: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG)) +NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG) + + +@triton.jit +def softplus(x): + # return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) + return tl.inline_asm_elementwise( + asm=asm_str, + constraints=constraints_str, + pack=NUM_REG, + args=[ + x, + ], + dtype=tl.float32, + is_pure=True, + )
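
Appendix (not part of the patch): a minimal usage sketch of the entry point added above. It assumes a CUDA device with Triton available, that parallel_stickbreaking_attn is importable from fla.ops.stickbreaking_attn.parallel as introduced in this series, and it follows the [B, T, H, D] layout documented in parallel_stickbreaking_attn_fwd. Folding the remainder term back onto v mirrors what the removed reference tests did; all shapes and dtypes below are illustrative only.

import math
import torch

from fla.ops.stickbreaking_attn.parallel import parallel_stickbreaking_attn

B, T, H, D = 2, 1024, 8, 64
q = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)

# o: stick-breaking weighted values, rem: unallocated stick mass per (batch, token, head).
# scale defaults to head_dim ** -0.5 when omitted.
o, rem = parallel_stickbreaking_attn(q, k, v, scale=1 / math.sqrt(D))

# Route the leftover mass to the current token's value, as the reference tests do.
o = o + rem[..., None] * v
o.backward(torch.randn_like(o))

# For variable-length packing, flatten all sequences into a single batch entry
# (the wrapper asserts q.shape[0] == 1) and pass cumulative boundaries, e.g.
# cu_seqlens = torch.tensor([0, 512, 1024], device='cuda', dtype=torch.int32).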