9 changes: 0 additions & 9 deletions Dockerfile
@@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && make build-awq

# Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica

# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
@@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
16 changes: 8 additions & 8 deletions flake.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion flake.nix
@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/torch-2.7";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/merge-with-kernel-builder";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
2 changes: 2 additions & 0 deletions nix/client.nix
@@ -1,6 +1,7 @@
{
buildPythonPackage,
poetry-core,
aiohttp,
huggingface-hub,
pydantic,
}:
@@ -15,6 +16,7 @@ buildPythonPackage {
build-system = [ poetry-core ];

dependencies = [
aiohttp
huggingface-hub
pydantic
];
4 changes: 2 additions & 2 deletions nix/server.nix
@@ -31,7 +31,7 @@
peft,
pillow,
prometheus-client,
punica-kernels,
punica-sgmv,
py-cpuinfo,
pydantic,
quantization,
@@ -107,7 +107,7 @@ buildPythonPackage {
peft
pillow
prometheus-client
punica-kernels
punica-sgmv
py-cpuinfo
pydantic
quantization
1 change: 0 additions & 1 deletion server/Makefile
@@ -3,7 +3,6 @@ include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-exllamav2
include Makefile-flashinfer

12 changes: 0 additions & 12 deletions server/Makefile-lorax-punica

This file was deleted.

58 changes: 58 additions & 0 deletions server/kernels.lock
@@ -163,6 +163,64 @@
}
}
},
{
"repo_id": "kernels-community/punica-sgmv",
"sha": "9ae1b469cb39c33df9ddd61657c6359acc423714",
"variants": {
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/quantization",
"sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",
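
For context: each key under "variants" names the build-matrix entry a prebuilt binary was compiled for — torch minor version, C++ ABI, CUDA version, CPU architecture, and OS. A hypothetical helper sketching how such a key lines up with the local environment (illustrative only; the actual resolution logic lives in the kernels package):

import platform

import torch


def local_variant_key() -> str:
    """Assemble a variant name like 'torch27-cxx11-cu128-x86_64-linux'.

    Hypothetical helper for illustration; the `kernels` package performs
    the real variant lookup against kernels.lock.
    """
    # "2.7.0+cu128" -> "torch27"
    torch_tag = "torch" + "".join(torch.__version__.split("+")[0].split(".")[:2])
    cxx_tag = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cuda_tag = "cu" + torch.version.cuda.replace(".", "")  # assumes a CUDA build
    return f"{torch_tag}-{cxx_tag}-{cuda_tag}-{platform.machine()}-linux"
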
1 change: 1 addition & 0 deletions server/pyproject.toml
@@ -58,6 +58,7 @@ build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/paged-attention" = ">=0.0.2"
"kernels-community/moe" = ">=0.1.1"
"kernels-community/punica-sgmv" = ">=0.0.1"
"kernels-community/quantization" = ">=0.0.3"
"kernels-community/quantization-eetq" = ">=0.0.1"
"kernels-community/rotary" = ">=0.0.1"
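
This version range is what gets pinned to an exact revision and per-variant hashes in server/kernels.lock above. A minimal sketch of fetching the kernel by hand, assuming the public kernels API that TGI's load_kernel helper wraps:

from kernels import get_kernel

# Downloads (or reuses from the local cache) the prebuilt variant matching
# the running torch/CUDA/platform combination and imports it as a module.
punica_sgmv = get_kernel("kernels-community/punica-sgmv")

# The module exposes the helpers formerly vendored in utils.sgmv, e.g.:
print(hasattr(punica_sgmv, "lora_a_sgmv_cutlass"))
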
41 changes: 23 additions & 18 deletions server/text_generation_server/adapters/lora.py
@@ -13,21 +13,20 @@
from text_generation_server.utils.log import log_master

from text_generation_server.adapters.config import AdapterConfig, ModuleMap

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
has_sgmv,
)

if SYSTEM == "cuda":
punica_sgmv = load_kernel(
module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
)
else:
punica_sgmv = None


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
@@ -129,11 +128,13 @@ def __init__(
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False

# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
]
self._weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
@@ -244,8 +245,12 @@ def prepare_weights(
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

# pad lora ranks to be compatible with sgmv
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
lora_a_list = [
punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
]

if lora_a_list:
# update rank if it was padded
@@ -293,7 +298,7 @@ def has_adapter(self, adapter_index: int) -> bool:

def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)

@@ -337,8 +342,8 @@ def load(
)

use_sgmv = False
if prefill or max_rank > BGMV_MAX_RANK:
if has_sgmv():
if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
if punica_sgmv is not None:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
@@ -425,7 +430,7 @@ def load(

if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
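
With the kernel module now loaded at module scope and left as None off CUDA, the old has_sgmv() feature probe reduces to a None check. A hypothetical helper condensing the dispatch condition inlined in load above (note that the None check must precede any attribute access on the module):

def sgmv_enabled(prefill: bool, max_rank: int) -> bool:
    # SGMV covers prefill and ranks beyond what BGMV supports; without the
    # CUDA kernel module there is nothing to vectorize with.
    if punica_sgmv is None:
        return False
    return prefill or max_rank > punica_sgmv.BGMV_MAX_RANK
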
34 changes: 20 additions & 14 deletions server/text_generation_server/layers/lora.py
@@ -5,14 +5,16 @@
from torch import nn
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
punica_sgmv = load_kernel(
module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
)
else:
punica_sgmv = None


if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
@@ -41,7 +43,11 @@ def forward_layer_type(
return result
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
if (
punica_sgmv is not None
and data is not None
and data.can_vectorize(self.process_group)
):
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on
# the layer type (e.g., attention vs. feed-forward layers). We define the current
@@ -68,7 +74,7 @@

if data.use_sgmv:
# Use SGMV for prefill
v = lora_a_sgmv_cutlass(
v = punica_sgmv.lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
@@ -81,7 +87,7 @@
if self.process_group.size() > 1:
v = self.collect_lora_a(v)

lora_b_sgmv_cutlass(
punica_sgmv.lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
@@ -96,7 +102,7 @@
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv(
punica_sgmv.add_lora_a_bgmv(
v,
input,
lora_a_ptr,
@@ -107,7 +113,7 @@
if self.process_group.size() > 1:
v = self.collect_lora_a(v)

add_lora_b_bgmv(
punica_sgmv.add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
@@ -142,7 +148,7 @@ def forward_lora(
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

lora_a = orient_for_rank(lora_a, lora_b.size(0))
lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))

a_out = input @ lora_a
if self.process_group.size() > 1:
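
For orientation, both kernel paths compute the same low-rank LoRA update; SGMV batches it over adapter segments (prefill, high rank), BGMV over single tokens (decode). A plain-PyTorch reference of the math, assuming a single adapter and no tensor parallelism:

import torch


def lora_reference(
    x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, scale: float = 1.0
) -> torch.Tensor:
    """Unfused reference for the update the punica kernels compute.

    x: [tokens, hidden], lora_a: [hidden, r], lora_b: [r, hidden_out].
    In the diff above, `scale` is folded into lora_b in prepare_weights.
    """
    return (x @ lora_a) @ lora_b * scale
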