3 changes: 2 additions & 1 deletion evals/harness.py
@@ -2,11 +2,12 @@

from __future__ import annotations

import fla # noqa
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

import fla # noqa


@register_model('fla')
class FlashLinearAttentionLMWrapper(HFLM):
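For context, the wrapper registered above makes `fla` models available to lm-evaluation-harness like any other model type. A minimal sketch of driving it through the Python API; the import path, checkpoint id, and task choice below are placeholders for illustration, not something this diff prescribes:

```python
import lm_eval

import evals.harness  # noqa  (assumed import path; importing it runs @register_model('fla'))

results = lm_eval.simple_evaluate(
    model='fla',                                      # the model type registered by the wrapper
    model_args='pretrained=fla-hub/gla-1.3B-100B',    # placeholder checkpoint id
    tasks=['lambada_openai'],                         # placeholder task
    batch_size=8,
)
print(results['results'])
```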
8 changes: 4 additions & 4 deletions fla/ops/nsa/parallel.py
@@ -542,7 +542,7 @@ def parallel_nsa_fwd(
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
- HQ = q.shape[2]
+ _, T_q, HQ, _ = q.shape
G = HQ // H
BS = block_size
if check_shared_mem('hopper', q.device.index):
@@ -555,9 +555,9 @@ def parallel_nsa_fwd(
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"

- grid = (T, NV, B * H)
- o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
- lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
+ grid = (T_q, NV, B * H)
+ o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
+ lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)

parallel_nsa_fwd_kernel[grid](
q=q,
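The substantive change above handles the case where the query carries fewer tokens than the key/value cache (e.g., single-token decoding): `HQ = q.shape[2]` is replaced by unpacking `_, T_q, HQ, _ = q.shape`, and the launch grid plus the `o`/`lse` buffers are sized by the query length `T_q` rather than the key length `T`. A minimal sketch of that shape logic (the helper name and example shapes are illustrative, not the library's API):

```python
import torch

def nsa_fwd_buffers(q: torch.Tensor, v: torch.Tensor):
    """Size forward outputs by the query length T_q, not the key/value length T."""
    B, T_q, HQ, _ = q.shape          # mirrors the patched unpacking above
    V = v.shape[-1]
    o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)
    return o, lse

# Decoding-style shapes: one query token attending over a 128-token KV cache.
q = torch.randn(2, 1, 8, 64)     # [B, T_q, HQ, K]
v = torch.randn(2, 128, 2, 32)   # [B, T,  H,  V]
o, lse = nsa_fwd_buffers(q, v)
print(o.shape, lse.shape)        # torch.Size([2, 1, 8, 32]) torch.Size([2, 1, 8])
```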
8 changes: 4 additions & 4 deletions fla/ops/simple_gla/README.md
@@ -1,10 +1,10 @@
# Simple GLA

Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
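For reference, the recurrence above is easy to write out naively. A minimal sketch assuming a `[B, H, T, ...]` layout and a readout $o_t = q_t S_t$; both are assumptions for illustration, and the actual kernels use a chunked matmul form rather than this Python loop:

```python
import torch

def simple_gla_recurrent(q, k, v, g):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; g: [B, H, T] head-wise scalar gates in (0, 1)
    B, H, T, K = k.shape
    S = k.new_zeros(B, H, K, v.shape[-1])                       # running state S_t
    outputs = []
    for t in range(T):
        # S_{t+1} = g_{t+1} * S_t + k_{t+1} v_{t+1}^T  (the scalar gate broadcasts over the head)
        S = g[:, :, t, None, None] * S + k[:, :, t, :, None] * v[:, :, t, None, :]
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], S))
    return torch.stack(outputs, dim=2)                          # [B, H, T, V]

out = simple_gla_recurrent(
    torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4),
    torch.randn(1, 2, 5, 3), torch.rand(1, 2, 5),
)
print(out.shape)  # torch.Size([1, 2, 5, 3])
```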
20 changes: 10 additions & 10 deletions legacy/training/README.md
@@ -7,14 +7,14 @@
> [!IMPORTANT]
> The `flame` project has been migrated to a new project built on torchtitan.
> Please visit the [new repository](https://github.com/fla-org/flame) for details and updates.
>
> The code here is now **archived as legacy**, and no future updates will be synchronized here.

A minimal framework for training FLA models, whether from scratch or through finetuning.

Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.

In this README, we will guide you through the process of using `flame` to train GLA models.

## Setup
@@ -25,7 +25,7 @@ Clone the `fla` repository and install the necessary packages as follows:

```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
pip install .
pip install accelerate
```

@@ -35,8 +35,8 @@ pip install accelerate

## Preprocessing

Before training, you need to download and pre-tokenize your dataset.
We provide a straightforward script for this.
For instance, to tokenize a 10B sample of the `fineweb-edu` dataset, run:

```bash
@@ -103,15 +103,15 @@ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported.

The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as
`batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`.
For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).

The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps.
Each step processes `global_batch_size` tokens.
Consequently, `512` and `20480` correspond to processing roughly 0.25B and 10B tokens, respectively.

:warning: Monitor the value of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!!
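As a quick sanity check of the arithmetic above (a throwaway snippet, not part of `flame`):

```python
def token_budget(batch_size, grad_accum, context_length, gpus_per_node, num_nodes, steps):
    # Tokens per optimizer step, and the total processed after `steps` steps.
    global_batch_size = batch_size * grad_accum * context_length * gpus_per_node * num_nodes
    return global_batch_size, global_batch_size * steps

# 340M example from the text: 0.5M tokens per step, ~10B tokens after 20480 steps.
print(token_budget(32, 1, 2048, 8, 1, steps=20480))  # (524288, 10737418240)
print(token_budget(32, 1, 2048, 8, 1, steps=512))    # (524288, 268435456) warmup tokens
```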

`flame` also supports resuming interrupted training by specifying the checkpoint path.
Simply use the following command:

```bash
@@ -141,7 +141,7 @@ You can also use `wandb` to monitor your training process effectively.
## Continual Pretraining

`flame` supports continual training from a pretrained checkpoint.
Below, we provide an example of how to finetune Mistral-7B to GLA.
You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):

1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B:
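The repository's own conversion snippet is collapsed in this diff. Purely as an illustration of the idea, here is a sketch that copies every tensor whose name and shape match from Mistral-7B into a freshly initialized GLA model; the model id, config path, and matching rule are assumptions, not the repo's script:

```python
import fla  # noqa: registers the GLA architecture with transformers
from transformers import AutoConfig, AutoModelForCausalLM

teacher = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-v0.1')
student = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained('configs/gla_7B.json'))

# Copy parameters with matching names and shapes; GLA-specific tensors
# (e.g. the gating projections) keep their fresh initialization.
src, dst = teacher.state_dict(), student.state_dict()
copied = {name: p for name, p in src.items() if name in dst and dst[name].shape == p.shape}
dst.update(copied)
student.load_state_dict(dst)
print(f'copied {len(copied)}/{len(dst)} tensors from Mistral-7B')
student.save_pretrained('gla-7B-init')
```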
@@ -171,7 +171,7 @@ bash train.sh \
cache=data/SlimPajama-627B/train
```

Please be aware that finetuning on a single node may not be the most efficient approach.
If available, consider leveraging multi-node GPUs for optimal performance.
You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).

2 changes: 1 addition & 1 deletion legacy/training/configs/gla_1B.json
@@ -22,4 +22,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}
2 changes: 1 addition & 1 deletion legacy/training/configs/gla_340M.json
@@ -21,4 +21,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}
2 changes: 1 addition & 1 deletion legacy/training/configs/gla_7B.json
@@ -25,4 +25,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}
2 changes: 1 addition & 1 deletion legacy/training/configs/transformer_340M.json
@@ -15,4 +15,4 @@
"tie_word_embeddings": true,
"use_cache": true,
"vocab_size": 32000
}
3 changes: 1 addition & 2 deletions legacy/training/flame/logging.py
@@ -6,8 +6,7 @@
import sys
import time

from transformers.trainer_callback import (ExportableState, TrainerCallback,
                                            TrainerControl, TrainerState)
from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments


3 changes: 1 addition & 2 deletions legacy/training/flame/parser.py
@@ -6,9 +6,8 @@
from typing import Optional

import transformers
from transformers import HfArgumentParser, TrainingArguments

from flame.logging import get_logger
from transformers import HfArgumentParser, TrainingArguments

logger = get_logger(__name__)

7 changes: 3 additions & 4 deletions legacy/training/run.py
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-

from datasets import load_from_disk
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          Trainer)

import fla # noqa
from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer

import fla # noqa

logger = get_logger(__name__)
