33import uk .ac .manchester .tornado .api .KernelContext ;
44import uk .ac .manchester .tornado .api .math .TornadoMath ;
55import uk .ac .manchester .tornado .api .types .HalfFloat ;
6+ import uk .ac .manchester .tornado .api .types .arrays .ByteArray ;
67import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
78import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
89
@@ -26,6 +27,39 @@ public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, Fl
2627 wrapX .set (i , x .get (i ).getFloat32 ());
2728 }
2829
30+ public static void convertQ8_0toFP32 (KernelContext context , ByteArray x , FloatArray wrapX ) {
31+ int globalId = context .globalIdx ;
32+ int totalElements = wrapX .getSize ();
33+
34+ if (globalId >= totalElements ) {
35+ return ;
36+ }
37+
38+ // Q8_0 block structure constants
39+ int blockSize = 32 ;
40+ int Q8_0_BLOCK_BYTES = 34 ; // 2 bytes scale + 32 bytes quants
41+
42+ // Calculate which block and position within block
43+ int blockIdx = globalId / blockSize ;
44+ int withinBlockIdx = globalId % blockSize ;
45+
46+ // Calculate byte offset for this Q8_0 block
47+ int blockByteOffset = blockIdx * Q8_0_BLOCK_BYTES ;
48+
49+ // Load scale (first 2 bytes of block as HalfFloat)
50+ HalfFloat scale = x .getHalf (blockByteOffset );
51+ float scaleFloat = scale .getFloat32 ();
52+
53+ // Load quantized value (skip 2-byte scale, then index within block)
54+ byte quantValue = x .get (blockByteOffset + 2 + withinBlockIdx );
55+
56+ // Dequantize: float_value = quantized_value * scale
57+ float dequantizedValue = ((float ) quantValue ) * scaleFloat ;
58+
59+ // Store result in output FloatArray
60+ wrapX .set (globalId , dequantizedValue );
61+ }
62+
2963 public static void convertFP32toFP16 (KernelContext context , FloatArray wrapX , HalfFloatArray x ) {
3064 int i = context .globalIdx ;
3165 float valInput = wrapX .get (i );
0 commit comments