Skip to content

Commit 7456d59

Browse files
Add convertQ8_0toFP32 kernel for dequantization in TransformerComputeKernels
1 parent 78d6a18 commit 7456d59

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uk.ac.manchester.tornado.api.KernelContext;
44
import uk.ac.manchester.tornado.api.math.TornadoMath;
55
import uk.ac.manchester.tornado.api.types.HalfFloat;
6+
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
67
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
78
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
89

@@ -26,6 +27,39 @@ public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, Fl
2627
wrapX.set(i, x.get(i).getFloat32());
2728
}
2829

30+
public static void convertQ8_0toFP32(KernelContext context, ByteArray x, FloatArray wrapX) {
31+
int globalId = context.globalIdx;
32+
int totalElements = wrapX.getSize();
33+
34+
if (globalId >= totalElements) {
35+
return;
36+
}
37+
38+
// Q8_0 block structure constants
39+
int blockSize = 32;
40+
int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
41+
42+
// Calculate which block and position within block
43+
int blockIdx = globalId / blockSize;
44+
int withinBlockIdx = globalId % blockSize;
45+
46+
// Calculate byte offset for this Q8_0 block
47+
int blockByteOffset = blockIdx * Q8_0_BLOCK_BYTES;
48+
49+
// Load scale (first 2 bytes of block as HalfFloat)
50+
HalfFloat scale = x.getHalf(blockByteOffset);
51+
float scaleFloat = scale.getFloat32();
52+
53+
// Load quantized value (skip 2-byte scale, then index within block)
54+
byte quantValue = x.get(blockByteOffset + 2 + withinBlockIdx);
55+
56+
// Dequantize: float_value = quantized_value * scale
57+
float dequantizedValue = ((float) quantValue) * scaleFloat;
58+
59+
// Store result in output FloatArray
60+
wrapX.set(globalId, dequantizedValue);
61+
}
62+
2963
public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
3064
int i = context.globalIdx;
3165
float valInput = wrapX.get(i);

0 commit comments

Comments
 (0)