From 5d01b70d6e46c8cb422617965c512c1fd542bd4f Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Wed, 13 Aug 2025 22:54:45 -0700 Subject: [PATCH 1/2] Modify output shape in nsa for decoding --- fla/ops/nsa/parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fla/ops/nsa/parallel.py b/fla/ops/nsa/parallel.py index d1974cfe8..b386a830d 100644 --- a/fla/ops/nsa/parallel.py +++ b/fla/ops/nsa/parallel.py @@ -542,7 +542,7 @@ def parallel_nsa_fwd( token_indices: Optional[torch.LongTensor] = None, ): B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] - HQ = q.shape[2] + _, T_q, HQ, _ = q.shape G = HQ // H BS = block_size if check_shared_mem('hopper', q.device.index): @@ -555,9 +555,9 @@ def parallel_nsa_fwd( NV = triton.cdiv(V, BV) assert NK == 1, "The key dimension can not be larger than 256" - grid = (T, NV, B * H) - o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) - lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + grid = (T_q, NV, B * H) + o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device) parallel_nsa_fwd_kernel[grid]( q=q, From 1b3483c0bddb03b6d315de26a120510de4e11423 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Wed, 13 Aug 2025 23:00:48 -0700 Subject: [PATCH 2/2] Modify output shape in nsa for decoding with the true format --- evals/harness.py | 3 ++- fla/ops/nsa/parallel.py | 4 ++-- fla/ops/simple_gla/README.md | 8 ++++---- legacy/training/README.md | 20 +++++++++---------- legacy/training/configs/gla_1B.json | 2 +- legacy/training/configs/gla_340M.json | 2 +- legacy/training/configs/gla_7B.json | 2 +- legacy/training/configs/transformer_340M.json | 2 +- legacy/training/flame/logging.py | 3 +-- legacy/training/flame/parser.py | 3 +-- legacy/training/run.py | 7 +++---- 11 files changed, 27 insertions(+), 29 deletions(-) diff --git a/evals/harness.py b/evals/harness.py index 24739c9c0..e54620783 100644 --- a/evals/harness.py +++ b/evals/harness.py @@ -2,11 +2,12 @@ from __future__ import annotations -import fla # noqa from lm_eval.__main__ import cli_evaluate from lm_eval.api.registry import register_model from lm_eval.models.huggingface import HFLM +import fla # noqa + @register_model('fla') class FlashLinearAttentionLMWrapper(HFLM): diff --git a/fla/ops/nsa/parallel.py b/fla/ops/nsa/parallel.py index b386a830d..e5a76a3d7 100644 --- a/fla/ops/nsa/parallel.py +++ b/fla/ops/nsa/parallel.py @@ -542,7 +542,7 @@ def parallel_nsa_fwd( token_indices: Optional[torch.LongTensor] = None, ): B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] - _, T_q, HQ, _ = q.shape + _, T_q, HQ, _ = q.shape G = HQ // H BS = block_size if check_shared_mem('hopper', q.device.index): @@ -556,7 +556,7 @@ def parallel_nsa_fwd( assert NK == 1, "The key dimension can not be larger than 256" grid = (T_q, NV, B * H) - o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device) + o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device) lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device) parallel_nsa_fwd_kernel[grid]( diff --git a/fla/ops/simple_gla/README.md b/fla/ops/simple_gla/README.md index 2a64f3dcd..c359ced5e 100644 --- a/fla/ops/simple_gla/README.md +++ b/fla/ops/simple_gla/README.md @@ -1,10 +1,10 @@ # Simple GLA -Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). 
+Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). -Compared to GLA, the gating is head-wise instead of elementwise. -As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. -It is faster than GLA but has less expressive power. +Compared to GLA, the gating is head-wise instead of elementwise. +As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. +It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. diff --git a/legacy/training/README.md b/legacy/training/README.md index 13bbc1730..9a21de234 100644 --- a/legacy/training/README.md +++ b/legacy/training/README.md @@ -7,14 +7,14 @@ > [!IMPORTANT] > The `flame` project has been migrated to a new project built on torchtitan. > Please visit the [new repository](https://github.com/fla-org/flame) for details and updates. -> +> > The code here is now **archived as legacy**, and no future updates will be synchronized here. A minimal framework for training FLA models, whether from scratch or through finetuning. Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code: we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training. - + In this README, we will guide you through the process of using `flame` to train GLA models. ## Setup @@ -25,7 +25,7 @@ Clone the `fla` repository and install the necessary packages as follows: ```bash git clone https://github.com/sustcsonglin/flash-linear-attention.git -pip install . +pip install . pip install accelerate ``` @@ -35,8 +35,8 @@ pip install accelerate ## Preprocessing -Before training, you need to download and pre-tokenize your dataset. -We provide a straightforward script for this. +Before training, you need to download and pre-tokenize your dataset. +We provide a straightforward script for this. For instance, to tokenize a 10B sample of the `fineweb-edu` dataset, run: ```bash @@ -103,15 +103,15 @@ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported. The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`. -For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens). +For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens). The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps. -Each step processes `global_batch_size` tokens. +Each step processes `global_batch_size` tokens. Consequently, `512` and `20480` correspond to processing 0.5B and 10B tokens, respectively. :warning: Monitor the value of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!! -`flame` also supports resuming interrupted training by specifying the checkpoint path. +`flame` also supports resuming interrupted training by specifying the checkpoint path. 
Simply use the following command: ```bash @@ -141,7 +141,7 @@ You can also use `wandb` to monitor your training process effectively. ## Continual Pretraining `flame` supports continual training from a pretrained checkpoint. -Below, we provide an example of how to finetune Mistral-7B to GLA. +Below, we provide an example of how to finetune Mistral-7B to GLA. You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146): 1. Initialize a brand-new GLA-7B model from the config and copy the mathced pretrained weights from Mistral-7B: @@ -171,7 +171,7 @@ bash train.sh \ cache=data/SlimPajama-627B/train ``` -Please be aware that finetuning on a single node may not be the most efficient approach. +Please be aware that finetuning on a single node may not be the most efficient approach. If available, consider leveraging multi-node GPUs for optimal performance. You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh). diff --git a/legacy/training/configs/gla_1B.json b/legacy/training/configs/gla_1B.json index eed54325e..95ef59945 100644 --- a/legacy/training/configs/gla_1B.json +++ b/legacy/training/configs/gla_1B.json @@ -22,4 +22,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/gla_340M.json b/legacy/training/configs/gla_340M.json index 378d80e70..bcb3fc3b0 100644 --- a/legacy/training/configs/gla_340M.json +++ b/legacy/training/configs/gla_340M.json @@ -21,4 +21,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/gla_7B.json b/legacy/training/configs/gla_7B.json index ca5658aab..c321d3d72 100644 --- a/legacy/training/configs/gla_7B.json +++ b/legacy/training/configs/gla_7B.json @@ -25,4 +25,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/transformer_340M.json b/legacy/training/configs/transformer_340M.json index e703797ca..08356de26 100644 --- a/legacy/training/configs/transformer_340M.json +++ b/legacy/training/configs/transformer_340M.json @@ -15,4 +15,4 @@ "tie_word_embeddings": true, "use_cache": true, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/flame/logging.py b/legacy/training/flame/logging.py index 0b5ebe3d3..9b572d6aa 100644 --- a/legacy/training/flame/logging.py +++ b/legacy/training/flame/logging.py @@ -6,8 +6,7 @@ import sys import time -from transformers.trainer_callback import (ExportableState, TrainerCallback, - TrainerControl, TrainerState) +from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerControl, TrainerState from transformers.training_args import TrainingArguments diff --git a/legacy/training/flame/parser.py b/legacy/training/flame/parser.py index 3b54d76e2..921fcb4d9 100644 --- a/legacy/training/flame/parser.py +++ b/legacy/training/flame/parser.py @@ -6,9 +6,8 @@ from typing import Optional import transformers -from transformers import HfArgumentParser, TrainingArguments - from flame.logging import get_logger +from transformers import HfArgumentParser, TrainingArguments logger = get_logger(__name__) diff --git a/legacy/training/run.py b/legacy/training/run.py index 0689d28fa..151324919 100644 --- a/legacy/training/run.py +++ b/legacy/training/run.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- from datasets import 
load_from_disk -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - Trainer) - -import fla # noqa from flame.data import DataCollatorForLanguageModeling from flame.logging import LogCallback, get_logger from flame.parser import get_train_args +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer + +import fla # noqa logger = get_logger(__name__)
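
Note on the change in `fla/ops/nsa/parallel.py` (patch 1/2): during single-token decoding, `q` carries its own sequence length (typically 1) while `k`/`v` keep the full context length `T`, so the launch grid and the `o`/`lse` buffers must be sized from `q.shape` rather than `k.shape`. Below is a minimal, standalone sketch of that shape bookkeeping only; the helper name and the toy dimensions are illustrative and not part of the library.

```python
import torch


def nsa_fwd_output_shapes(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Mirror the allocation logic of `parallel_nsa_fwd` after the patch.

    q: (B, T_q, HQ, K)  -- T_q == 1 for a single decoding step
    k: (B, T,   H,  K)  -- T is the full (cached) context length
    v: (B, T,   H,  V)
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # Before the patch the query length was implicitly assumed to equal T;
    # after the patch it is read from q itself.
    _, T_q, HQ, _ = q.shape
    assert q.shape[-1] == K and HQ % H == 0  # GQA: HQ query heads share H kv heads
    o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)
    return o.shape, lse.shape


if __name__ == "__main__":
    B, T, H, HQ, K, V = 2, 128, 2, 16, 64, 64
    k = torch.randn(B, T, H, K)
    v = torch.randn(B, T, H, V)
    q = torch.randn(B, 1, HQ, K)  # one new query token attending to the cache
    o_shape, lse_shape = nsa_fwd_output_shapes(q, k, v)
    print(o_shape)    # torch.Size([2, 1, 16, 64]) -- sized by T_q, not T
    print(lse_shape)  # torch.Size([2, 1, 16])
```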