Skip to content

Commit e89dc5b

Browse files
committed
Fix for older Torch versions
1 parent 15d3989 commit e89dc5b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

exllamav2/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def unpack_4bit(packed: torch.Tensor):
341341

342342
m, n8 = packed.shape
343343
n = n8 * 8
344-
assert packed.dtype in [torch.int32, torch.uint32]
344+
assert packed.dtype in [torch.int32]
345345

346346
# packed = packed.view(torch.uint32)
347347
unpacked = torch.empty((m, n), dtype = torch.uint8, device = packed.device)
@@ -366,5 +366,5 @@ def pack_4bit(unpacked: torch.Tensor):
366366
packed = torch.zeros((m, n // 8), dtype = torch.int64, device = unpacked.device)
367367
for i in range(8):
368368
packed |= (unpacked[:, i::8].to(torch.int64) << (i * 4))
369-
packed = packed.to(torch.uint32)
370-
return packed.view(torch.int32)
369+
packed = packed.to(torch.int32)
370+
return packed

0 commit comments

Comments
 (0)