 from text_generation_server.utils.log import log_master

 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
-
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
     BatchAdapterWeights,
 )
-from text_generation_server.utils.sgmv import (
-    BGMV_MAX_RANK,
-    MAX_RANK_CUSTOM,
-    get_tmp_tensors,
-    orient_for_rank,
-    pad_rank,
-    use_cutlass_shrink,
-    has_sgmv,
-)
+
+if SYSTEM == "cuda":
+    punica_sgmv = load_kernel(
+        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
+    )
+else:
+    punica_sgmv = None


 def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
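The import hunk above drops the in-tree `text_generation_server.utils.sgmv` helpers and instead loads the Punica SGMV kernel from the Hub at import time, but only on CUDA systems; everywhere else `punica_sgmv` stays `None` and call sites have to check for that. A minimal sketch of the pattern is shown below; `SYSTEM` and `load_kernel` are assumed to behave as in the diff, and the `sgmv_available` wrapper is hypothetical, standing in for the removed `has_sgmv()` check.

```python
# Sketch of the guarded kernel-loading pattern introduced in the hunk above.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # Fetch the pre-built punica SGMV kernel from the Hub instead of
    # shipping its wrappers inside text-generation-server.
    punica_sgmv = load_kernel(
        module="punica_sgmv", repo_id="kernels-community/punica-sgmv"
    )
else:
    # No CUDA kernel: every call site must test for None before use.
    punica_sgmv = None


def sgmv_available() -> bool:
    # Hypothetical stand-in for the removed `has_sgmv()` import.
    return punica_sgmv is not None
```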
@@ -129,11 +128,13 @@ def __init__(
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

-        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False

         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        weights_a = [
+            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
+        ]
         self._weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
@@ -244,8 +245,12 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
-        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+        lora_a_list = [
+            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
+        ]
+        lora_b_list = [
+            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
+        ]

         if lora_a_list:
             # update rank if it was padded
@@ -293,7 +298,7 @@ def has_adapter(self, adapter_index: int) -> bool:

     def can_vectorize(self, pg: ProcessGroup) -> bool:
         return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+            rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM
             for rank_data in self.rank_data.values()
         )

@@ -337,8 +342,8 @@ def load(
         )

         use_sgmv = False
-        if prefill or max_rank > BGMV_MAX_RANK:
-            if has_sgmv():
+        if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:
+            if punica_sgmv is not None:
                 use_sgmv = True
             lora_a_ptr = torch.tensor(
                 [
@@ -425,7 +430,7 @@ def load(

             if use_sgmv:
                 lora_a_ptr_indices = lora_a_ptr[indices]
-                tmp_shrink, tmp_expand = get_tmp_tensors(
+                tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(
                     lora_a_ptr_indices.size(0), rank, device
                 )
                 segment_starts = meta.adapter_segments[indices]
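Taken together, the two `load` hunks above decide between the SGMV and BGMV paths: SGMV is preferred for prefill batches or when the maximum LoRA rank exceeds `BGMV_MAX_RANK`, and only when the kernel was actually loaded. A minimal sketch of that rule, factored into a standalone helper for readability, follows; `should_use_sgmv` is hypothetical and not part of the patch, and it tests for the missing kernel before reading any attribute from it.

```python
# Hypothetical helper illustrating the SGMV/BGMV gating applied in `load` above.
# `punica_sgmv` is the kernel module loaded at import time, or None off-CUDA.
def should_use_sgmv(prefill: bool, max_rank: int, punica_sgmv) -> bool:
    if punica_sgmv is None:
        # Kernel not available (non-CUDA system): fall back to the BGMV path.
        return False
    # SGMV handles prefill batches and ranks above the BGMV limit.
    return prefill or max_rank > punica_sgmv.BGMV_MAX_RANK
```

Checking for `None` first keeps the sketch safe on systems where the kernel was never loaded.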