Skip to content

Commit 8fb6cd1

Browse files
[hack] Fix backwards compatibility with Q8_0
1 parent fad57a9 commit 8fb6cd1

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GG
153153
/**
154154
* Load a tensor and manually convert to FP32 (FloatArray).
155155
* Used for embeddings that currently are treated as FP32.
156-
* TODO: it is ultra-slow and will be removed
156+
* TODO: it is ultra-slow and should be removed
157157
*/
158158
public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
159159
TornadoTensor tensor = loadTornadoTensor(entry);
@@ -168,6 +168,16 @@ public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
168168
}
169169
yield new FP32TornadoTensor(tensorFA);
170170
}
171+
case Q8_0 -> {
172+
Q8_0TornadoTensor tensorQ8_0 = Q8_0TornadoTensor.create(entry);
173+
int numOfElements = tensorQ8_0.getSize();
174+
FloatArray tensorFA = new FloatArray(numOfElements);
175+
for(int i = 0; i < numOfElements; i++) {
176+
tensorFA.set(i, tensorQ8_0.getFloat(i));
177+
}
178+
yield new FP32TornadoTensor(tensorFA);
179+
180+
}
171181
default -> { throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type()); }
172182
};
173183
}

src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,30 @@
66
import uk.ac.manchester.tornado.api.types.HalfFloat;
77
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
88
import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
9+
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
910

1011
import java.lang.foreign.MemorySegment;
1112
import java.lang.foreign.ValueLayout;
1213
import java.nio.ByteOrder;
1314

1415
public class Q8_0TornadoTensor extends TornadoTensor {
1516

17+
private final int size;
1618
private final HalfFloatArray scales; // One per 32-element block
1719
private final Int8Array quants; // Quantized int8 values
1820
private MemorySegment segment;
1921

20-
public Q8_0TornadoTensor(HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
22+
public Q8_0TornadoTensor(int size, HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
23+
this.size = size;
2124
this.scales = scales;
2225
this.quants = quants;
2326
this.segment = segment;
2427
}
2528

29+
public int getSize() {
30+
return size;
31+
}
32+
2633
/**
2734
* Returns the scale factors for GPU kernels.
2835
*
@@ -77,7 +84,10 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
7784
throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name());
7885
}
7986

80-
MemorySegment q8Segment = entry.memorySegment();
87+
// TODO: fix Q8_0 loading in tornado layoyt
88+
// currently we end up to hack it by removing
89+
// tornado header from memory segment
90+
MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
8191

8292
// allocate the arrays for quantized data (int8) and scales (fp16)
8393
HalfFloatArray scales = new HalfFloatArray(numBlocks);
@@ -103,6 +113,6 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
103113
}
104114
}
105115

106-
return new Q8_0TornadoTensor(scales, quants, q8Segment);
116+
return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
107117
}
108118
}

0 commit comments

Comments
 (0)