Commit 15d3989

Fix Q cache for TP mode

1 parent 4c6dc58
6 files changed: +60 -32 lines changed

exllamav2/attn.py

Lines changed: 3 additions & 5 deletions
@@ -614,8 +614,6 @@ def forward_paged_tp(
         cfg = self.model.config
         ctx = self.model.tp_context
 
-        assert cache.q_block != 1, \
-            "Models with odd key/value dims not supported in TP mode with quantized cache"
         assert not self.sliding_window, \
             "Sliding window not supported in TP mode"
 
@@ -631,7 +629,7 @@ def forward_paged_tp(
             self.layer_idx,
             batch_size,
             0,
-            attn_params.max_cache_seqlen if cache.q_block > 1 else 0,
+            attn_params.max_cache_seqlen,
             page_size,
             attn_params.cache_seqlens_tp,
             attn_params.block_index_tp
@@ -706,7 +704,7 @@ def forward_paged_tp_old(
             self.layer_idx,
             batch_size,
             0,
-            attn_params.max_cache_seqlen if cache.q_block > 1 else 0,
+            attn_params.max_cache_seqlen,
             page_size,
             attn_params.cache_seqlens_tp,
             attn_params.block_index_tp
@@ -1171,7 +1169,7 @@ def forward_tp(
         )
 
         if cache is not None:
-            cache.store_kv_state(self.layer_idx, batch_size, 0, q_len)
+            cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
 
         return ctx.get_pinned(0, batch_size, q_len, cfg.hidden_size)

exllamav2/cache.py

Lines changed: 4 additions & 1 deletion
@@ -432,8 +432,11 @@ def __init__(
         # Models with odd key/value dims need to quantize/dequantize in multi-token blocks. Make sure the quant
         # blocksize aligns with a whole number of tokens
 
+        if not num_key_value_heads:
+            num_key_value_heads = cfg.num_key_value_heads
+
         Q_CACHE_BLOCKSIZE_Q = 512
-        kv_dim = cfg.num_key_value_heads * cfg.head_dim
+        kv_dim = num_key_value_heads * cfg.head_dim
         self.q_block = 1
         while (kv_dim * self.q_block) % Q_CACHE_BLOCKSIZE_Q:
             self.q_block += 1

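For reference, a standalone sketch of how the block size falls out of this loop; the head counts below are illustrative, and the num_key_value_heads override presumably lets a TP cache shard holding fewer heads per device compute its own block size:

Q_CACHE_BLOCKSIZE_Q = 512

def q_block_for(num_key_value_heads: int, head_dim: int) -> int:
    # Smallest whole number of tokens whose K/V elements fill complete 512-element quant blocks
    kv_dim = num_key_value_heads * head_dim
    q_block = 1
    while (kv_dim * q_block) % Q_CACHE_BLOCKSIZE_Q:
        q_block += 1
    return q_block

print(q_block_for(8, 128))  # kv_dim = 1024 -> q_block = 1, each token aligns on its own
print(q_block_for(5, 96))   # kv_dim = 480  -> q_block = 16, quantize in 16-token blocks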
exllamav2/exllamav2_ext/cuda/cache.cu

Lines changed: 5 additions & 9 deletions
@@ -172,12 +172,10 @@ __global__ void fp16_to_q_kv_paged_kernel
     int px_a = seqlen - vx_a;
     int px_b = px_a + q_len;
 
-    if (dim < BLOCKSIZE_Q)
+    if (dim % BLOCKSIZE_Q)
     {
-        int g = BLOCKSIZE_Q / dim;
-        // if (px_a > 0) DBGI4(px_a, px_b, px_a / g * g, DIVIDE(px_b, g) * g);
-        px_a = px_a / g * g;
-        px_b = DIVIDE(px_b, g) * g;
+        while ((px_a * dim) % BLOCKSIZE_Q) px_a--;
+        while ((px_b * dim) % BLOCKSIZE_Q) px_b++;
     }
 
     px_a = max(px_a, 0);
@@ -372,10 +370,8 @@ __global__ void q_to_fp16_kv_paged_kernel
 
     if (dim < BLOCKSIZE_Q)
     {
-        int g = BLOCKSIZE_Q / dim;
-        // if (vx_a > 0) DBGI4(vx_a, vx_b, vx_a / g * g, DIVIDE(vx_b, g) * g);
-        vx_a = vx_a / g * g;
-        vx_b = DIVIDE(vx_b, g) * g;
+        while ((vx_a * dim) % BLOCKSIZE_Q) vx_a--;
+        while ((vx_b * dim) % BLOCKSIZE_Q) vx_b++;
     }
 
     int vnum = max(vx_b - vx_a, 0);

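The replaced rounding (g = BLOCKSIZE_Q / dim) only lands on quant-block boundaries when dim divides the blocksize evenly; the new loops instead widen the token range until both ends line up. A standalone Python sketch of that widening, not the kernel code itself, with illustrative dim values:

BLOCKSIZE_Q = 512

def widen_token_range(a: int, b: int, dim: int) -> tuple[int, int]:
    # Widen [a, b) until both ends sit on element offsets that are multiples of the quant block
    if dim % BLOCKSIZE_Q:
        while (a * dim) % BLOCKSIZE_Q: a -= 1
        while (b * dim) % BLOCKSIZE_Q: b += 1
    return max(a, 0), b

print(widen_token_range(37, 38, 480))   # -> (32, 48): block boundaries fall every 16 tokens
print(widen_token_range(37, 38, 1024))  # -> (37, 38): already aligned, left untouched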
exllamav2/exllamav2_ext/ext_cache.cpp

Lines changed: 18 additions & 8 deletions
@@ -155,9 +155,14 @@ void fp16_to_q_kv
     int stride = k_in.size(1) * k_in.size(2) * k_in.size(3);
     int height = batch_size;
 
-    int tsize = k_in.size(2) * k_in.size(3);
-    offset *= tsize;
-    width *= tsize;
+    int dim = k_in.size(2) * k_in.size(3);
+    if (dim % Q_CACHE_BLOCKSIZE_Q)
+    {
+        while ((offset * dim) % Q_CACHE_BLOCKSIZE_Q) offset--;
+        while ((width * dim) % Q_CACHE_BLOCKSIZE_Q) width++;
+    }
+    offset *= dim;
+    width *= dim;
 
     array_fp16_to_q_kv_cuda
     (
@@ -168,7 +173,7 @@ void fp16_to_q_kv
         (const half*) v_in.data_ptr(),
         (unsigned char*) v_out.data_ptr(),
         (half*) v_scales.data_ptr(),
-        tsize,
+        dim,
         stride,
         height,
         offset,
@@ -257,9 +262,14 @@ void q_to_fp16_kv
     int stride = k_out.size(1) * k_out.size(2) * k_out.size(3);
     int height = batch_size;
 
-    int tsize = k_out.size(2) * k_out.size(3);
-    offset *= tsize;
-    width *= tsize;
+    int dim = k_out.size(2) * k_out.size(3);
+    if (dim % Q_CACHE_BLOCKSIZE_Q)
+    {
+        while ((offset * dim) % Q_CACHE_BLOCKSIZE_Q) offset--;
+        while ((width * dim) % Q_CACHE_BLOCKSIZE_Q) width++;
+    }
+    offset *= dim;
+    width *= dim;
 
     array_q_to_fp16_kv_cuda
     (
@@ -270,7 +280,7 @@ void q_to_fp16_kv
         (const unsigned char*) v_in.data_ptr(),
         (const half*) v_scales.data_ptr(),
         (half*) v_out.data_ptr(),
-        tsize,
+        dim,
         stride,
         height,
         offset,

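The same widening is applied host-side to the token offset and width before they are scaled to element counts for the kernels. A sketch of that bookkeeping, again with illustrative values:

Q_CACHE_BLOCKSIZE_Q = 512

def widen_and_scale(offset: int, width: int, dim: int) -> tuple[int, int]:
    # Expand the token window to whole quant blocks, then convert token counts to element counts
    if dim % Q_CACHE_BLOCKSIZE_Q:
        while (offset * dim) % Q_CACHE_BLOCKSIZE_Q: offset -= 1
        while (width * dim) % Q_CACHE_BLOCKSIZE_Q: width += 1
    return offset * dim, width * dim

print(widen_and_scale(37, 1, 480))   # -> (15360, 7680): a 16-token, block-aligned slice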
exllamav2/model.py

Lines changed: 10 additions & 5 deletions
@@ -221,7 +221,10 @@ def set_device_map(self,
 
         self.device_context = []
         for idx, scratch_bytes in enumerate(fixed_bytes):
-            self.device_context.append(ExLlamaV2DeviceContext(self, idx, scratch_bytes))
+            if scratch_bytes > 0:
+                self.device_context.append(ExLlamaV2DeviceContext(self, idx, scratch_bytes))
+            else:
+                self.device_context.append(None)
 
         # Create map for cache
 
@@ -300,7 +303,8 @@ def load_tp(
         callback: Callable[[int, int], None] | None = None,
         callback_gen: Callable[[int, int], None] | None = None,
         progress: bool = False,
-        expect_cache_tokens: int = 0
+        expect_cache_tokens: int = 0,
+        expect_cache_base: type = None
     ):
 
         if progress:
@@ -313,7 +317,7 @@ def callback_pb(a, b):
            assert callback is None, \
                "Cannot use callback function and console progress bar at the same time."
            callback = callback_pb
-        f = self.load_tp_gen(gpu_split, callback, callback_gen, expect_cache_tokens)
+        f = self.load_tp_gen(gpu_split, callback, callback_gen, expect_cache_tokens, expect_cache_base)
         for item in f:
             pass
         if progress:
@@ -325,10 +329,11 @@ def load_tp_gen(
         gpu_split: list[float] | None = None,
         callback: Callable[[int, int], None] | None = None,
         callback_gen: Callable[[int, int], None] | None = None,
-        expect_cache_tokens: int = 0
+        expect_cache_tokens: int = 0,
+        expect_cache_base: type = None
     ):
         self.config.no_graphs = True
-        self.tp_context = TPContext(self, gpu_split, expect_cache_tokens)
+        self.tp_context = TPContext(self, gpu_split, expect_cache_tokens, expect_cache_base)
 
         # Create device tensors

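With expect_cache_base threaded through, a caller can tell the TP splitter which cache class it plans to use, so VRAM is budgeted for quantized rather than FP16 storage. A minimal usage sketch: the model path and token count are placeholders, and the ExLlamaV2Cache_TP base argument follows the pattern in the repo's TP examples rather than anything shown in this diff:

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q4, ExLlamaV2Cache_TP

config = ExLlamaV2Config("/path/to/model")     # placeholder model directory
model = ExLlamaV2(config)
model.load_tp(
    progress = True,
    expect_cache_tokens = 65536,               # planned total cache capacity, in tokens
    expect_cache_base = ExLlamaV2Cache_Q4,     # budget ~4.5 bits per element instead of 16
)
cache = ExLlamaV2Cache_TP(model, base = ExLlamaV2Cache_Q4, max_seq_len = 65536)

If expect_cache_base is left unset, the splitter falls back to the FP16 budget of 2 bytes per element, as before.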
exllamav2/tensor_p.py

Lines changed: 20 additions & 4 deletions
@@ -50,7 +50,8 @@ def __init__(
         self,
         model: ExLlamaV2,
         gpu_split: list[float] | None,
-        expect_cache_tokens: int = 0
+        expect_cache_tokens: int = 0,
+        expect_cache_base: type = None
     ):
         self.model = model
         cfg = self.model.config
@@ -80,7 +81,7 @@ def __init__(
         self.sin = None
         self.cos = None
 
-        self.define_split(gpu_split, expect_cache_tokens)
+        self.define_split(gpu_split, expect_cache_tokens, expect_cache_base)
 
 
     def unload(self):
@@ -98,7 +99,12 @@ def all_devices(self) -> list[int]:
         return sorted(devs)
 
 
-    def define_split(self, gpu_split: list[float] | None, expect_cache_tokens):
+    def define_split(
+        self,
+        gpu_split: list[float] | None,
+        expect_cache_tokens: int,
+        expect_cache_base: type
+    ):
         cfg = self.model.config
 
         if gpu_split is None:
@@ -119,8 +125,18 @@ def define_split(self, gpu_split: list[float] | None, expect_cache_tokens):
 
         if not expect_cache_tokens:
             expect_cache_tokens = cfg.max_seq_len * cfg.max_batch_size
+        if expect_cache_base == sys.modules["exllamav2.cache"].ExLlamaV2Cache_8bit:
+            bytes_per_element = 1
+        elif expect_cache_base == sys.modules["exllamav2.cache"].ExLlamaV2Cache_Q8:
+            bytes_per_element = 8.5/8
+        elif expect_cache_base == sys.modules["exllamav2.cache"].ExLlamaV2Cache_Q6:
+            bytes_per_element = 6.5/8
+        elif expect_cache_base == sys.modules["exllamav2.cache"].ExLlamaV2Cache_Q4:
+            bytes_per_element = 4.5/8
+        else:
+            bytes_per_element = 2
 
-        cache_size = 2 * 2 * cfg.num_key_value_heads * cfg.head_dim * cfg.num_hidden_layers * expect_cache_tokens
+        cache_size = 2 * bytes_per_element * cfg.num_key_value_heads * cfg.head_dim * cfg.num_hidden_layers * expect_cache_tokens
         gpu_split = [max(0, gs - int(cache_size * r / 1024**2)) for gs, r in zip(gpu_split, attn_ratio)]
 
         # Subtract size of attn layers

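The revised estimate is easy to check by hand. A worked example with illustrative model dimensions, printing the cache budget each type reserves before the weight split is computed:

num_key_value_heads, head_dim, num_hidden_layers = 8, 128, 32    # illustrative model dims
expect_cache_tokens = 65536

for name, bytes_per_element in [("FP16", 2), ("Q8", 8.5/8), ("Q6", 6.5/8), ("Q4", 4.5/8)]:
    # 2x for K and V, times bytes per stored element, per token, per layer
    cache_size = 2 * bytes_per_element * num_key_value_heads * head_dim \
                 * num_hidden_layers * expect_cache_tokens
    print(f"{name}: {cache_size / 1024**2:.0f} MiB")

# FP16: 8192 MiB   Q8: 4352 MiB   Q6: 3328 MiB   Q4: 2304 MiB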