66import uk .ac .manchester .tornado .api .types .HalfFloat ;
77import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
88import uk .ac .manchester .tornado .api .types .arrays .Int8Array ;
9+ import uk .ac .manchester .tornado .api .types .arrays .TornadoNativeArray ;
910
1011import java .lang .foreign .MemorySegment ;
1112import java .lang .foreign .ValueLayout ;
1213import java .nio .ByteOrder ;
1314
1415public class Q8_0TornadoTensor extends TornadoTensor {
1516
17+ private final int size ;
1618 private final HalfFloatArray scales ; // One per 32-element block
1719 private final Int8Array quants ; // Quantized int8 values
1820 private MemorySegment segment ;
1921
20- public Q8_0TornadoTensor (HalfFloatArray scales , Int8Array quants , MemorySegment segment ) {
/**
 * Creates a Q8_0 tensor from its already-separated parts. All arguments are
 * stored as given; no validation or copying happens here.
 *
 * @param size    value later reported by {@link #getSize()}
 * @param scales  scale factors, one per 32-element block
 * @param quants  quantized int8 values
 * @param segment memory segment kept alongside the unpacked arrays
 */
22+ public Q8_0TornadoTensor (int size , HalfFloatArray scales , Int8Array quants , MemorySegment segment ) {
23+ this .size = size ;
2124 this .scales = scales ;
2225 this .quants = quants ;
2326 this .segment = segment ;
2427 }
2528
/**
 * Returns the size value supplied to the constructor.
 * NOTE(review): presumably the tensor's logical element count (the factory
 * rejects sizes that are not a multiple of the Q8_0 block size) - confirm
 * against callers.
 */
29+ public int getSize () {
30+ return size ;
31+ }
32+
2633 /**
2734 * Returns the scale factors for GPU kernels.
2835 *
@@ -77,7 +84,10 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
7784 throw new IllegalArgumentException ("Q8_0 tensor size must be multiple of " + GGMLType .Q8_0 .getBlockSize () + ", got: " + size + " for tensor: " + entry .name ());
7885 }
7986
80- MemorySegment q8Segment = entry .memorySegment ();
87+ // TODO: fix Q8_0 loading in the Tornado layout.
88+ // For now we work around it by slicing the
89+ // Tornado array header off the memory segment.
90+ MemorySegment q8Segment = entry .memorySegment ().asSlice (TornadoNativeArray .ARRAY_HEADER );
8191
8292 // allocate the arrays for quantized data (int8) and scales (fp16)
8393 HalfFloatArray scales = new HalfFloatArray (numBlocks );
@@ -103,6 +113,6 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
103113 }
104114 }
105115
106- return new Q8_0TornadoTensor (scales , quants , q8Segment );
116+ return new Q8_0TornadoTensor (size , scales , quants , q8Segment );
107117 }
108118}
0 commit comments