
Commit edc8fac

Merge pull request #79 from orionpapadakis/opt/q8-load-bytearray
[Opt] Manipulation of Q8_0 tensors with Tornado `ByteArray`s
2 parents 7f9c5c6 + 6b66a59 commit edc8fac

30 files changed: +435 -326 lines

pom.xml

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@
     <dependency>
         <groupId>io.github.beehive-lab</groupId>
         <artifactId>tornado-api</artifactId>
-        <version>2.0.0</version>
+        <version>2.0.1-dev</version>
     </dependency>
     <dependency>
         <groupId>io.github.beehive-lab</groupId>
         <artifactId>tornado-runtime</artifactId>
-        <version>2.0.0</version>
+        <version>2.0.1-dev</version>
     </dependency>
 </dependencies>
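Both TornadoVM artifacts move from the 2.0.0 release to a 2.0.1-dev snapshot, presumably to pick up the native array types the rest of the patch relies on (ByteArray, HalfFloatArray, TornadoNativeArray).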

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 18 additions & 1 deletion
@@ -583,7 +583,24 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
         final Configuration configuration = model.configuration();
         final TornadoWeights weights = (TornadoWeights) model.weights();

-        MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+        switch (weights.getWeightType()) {
+            case F16 -> {
+                MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+                int bytes = Short.BYTES;
+                MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes);
+            }
+            case Q8_0 -> {
+                MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment();
+                int blockSize = 32;
+                int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
+                int blocksPerToken = (configuration.dim() + blockSize - 1) / blockSize; // ceiling division
+                long bytesPerToken = (long) blocksPerToken * Q8_0_BLOCK_BYTES;
+
+                MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken, state.embeddingX.getSegment(), 0, bytesPerToken);
+            }
+            default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
+        }

         return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
     }
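For context, the byte arithmetic in the Q8_0 branch mirrors the GGUF Q8_0 block layout: each 34-byte block packs a 2-byte half-precision scale followed by 32 signed 8-bit quants, so one token row of dim values spans ceil(dim / 32) blocks. A minimal standalone sketch of the same addressing (class and method names here are illustrative, not part of the patch):

// Q8_0 addressing sketch: one block = 2-byte fp16 scale + 32 int8 quants = 34 bytes.
public final class Q8_0Layout {
    static final int BLOCK_SIZE = 32;   // quantized values per block
    static final int BLOCK_BYTES = 34;  // 2 (scale) + 32 (quants)

    /** Byte offset of the first block of a token's embedding row. */
    static long rowOffsetBytes(int token, int dim) {
        int blocksPerRow = (dim + BLOCK_SIZE - 1) / BLOCK_SIZE; // ceiling division
        return (long) token * blocksPerRow * BLOCK_BYTES;
    }

    public static void main(String[] args) {
        // e.g. dim = 2048 -> 64 blocks per row -> 2176 bytes per row
        System.out.println(rowOffsetBytes(1, 2048)); // prints 2176
    }
}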

src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java

Lines changed: 7 additions & 0 deletions
@@ -3,7 +3,9 @@
 import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 import java.util.stream.Stream;
@@ -52,6 +54,11 @@ protected StateFields createStateFields(Configuration config) {
         fields.wrapHb = new FloatArray(config.hiddenDim());
         fields.wrapHb2 = new FloatArray(config.hiddenDim());

+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
         fields.wrapLogits = new FloatArray(config.vocabularySize());
         fields.wrapQ = new FloatArray(config.dim());
         fields.wrapK = new FloatArray(config.dim());
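The same FP16/Q8_0 dispatch recurs in the Phi3State, Qwen2State, and Qwen3State diffs below; the createActivationFP16 and createActivationQ8_0 helpers it calls are defined on State.StateFields in the State.java diff further down.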

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 import java.util.stream.Stream;
@@ -79,6 +80,11 @@ protected StateFields createStateFields(Configuration config) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);

         // TornadoVM wrapper arrays for GPU acceleration
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
         fields.wrapX = new FloatArray(dim);
         fields.wrapXb = new FloatArray(dim);
         fields.wrapXb2 = new FloatArray(dim);

src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 import java.util.stream.Stream;
@@ -40,6 +41,11 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

         // TornadoVM wrappers with Qwen2 dimensions
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
         fields.wrapX = new FloatArray(config.dim());
         fields.wrapXb = new FloatArray(config.dim());
         fields.wrapXb2 = new FloatArray(config.dim());

src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java

Lines changed: 8 additions & 1 deletion
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 import java.util.stream.Stream;
@@ -65,6 +66,13 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

         // TornadoVM wrappers with Qwen3-specific sizes
+
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
+
         fields.wrapX = new FloatArray(config.dim());
         fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
         fields.wrapXb2 = new FloatArray(config.dim());
@@ -74,7 +82,6 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads());
         fields.wrapK = new FloatArray(nEmbdKGqa);
         fields.wrapV = new FloatArray(nEmbdKGqa);
-
         fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
         fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
         fields.wrapValueCache.init(0.f);

src/main/java/org/beehive/gpullama3/inference/state/State.java

Lines changed: 17 additions & 2 deletions
@@ -2,8 +2,8 @@

 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.*;

 /**
  * Represents the base state structure used during LLM inference.
@@ -57,6 +57,7 @@ public abstract class State {
     public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
     public final IntArray positionHolder;

+    public TornadoNativeArray embeddingX;
     // store inter
     public int localSize;
     public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
@@ -88,6 +89,7 @@ protected State(Configuration config, int batchsize) {
         this.keyCache = fields.keyCache;
         this.valueCache = fields.valueCache;

+        this.embeddingX = fields.embeddingX;
         this.wrapX = fields.wrapX;
         this.wrapXb = fields.wrapXb;
         this.wrapXb2 = fields.wrapXb2;
@@ -121,6 +123,19 @@ protected static class StateFields {
         public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
         public IntArray positionHolder;
         public FloatArray temp, tempFFN, tempLogits;
+        public TornadoNativeArray embeddingX;
+
+        public void createActivationFP16(int size) {
+            this.embeddingX = new HalfFloatArray(size);
+        }
+
+        public void createActivationQ8_0(int size) {
+            int blockSize = 32;
+            int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
+            int blocksNeeded = (size + blockSize - 1) / blockSize;
+            int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
+            this.embeddingX = new ByteArray(q8BytesNeeded);
+        }
     }

     @Override
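Since the Q8_0 activation buffer allocated by createActivationQ8_0 holds raw GGUF blocks, consumers have to dequantize on read. A minimal sketch of decoding one such block, assuming JDK 20+ for Float.float16ToFloat (this helper is illustrative and not part of the patch):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Decode one Q8_0 block: a 2-byte little-endian fp16 scale followed by 32 int8 quants.
final class Q8_0Decoder {
    static float[] dequantizeBlock(ByteBuffer buf, int blockIndex) {
        final int BLOCK_BYTES = 34, BLOCK_SIZE = 32;
        int base = blockIndex * BLOCK_BYTES;
        float scale = Float.float16ToFloat(buf.order(ByteOrder.LITTLE_ENDIAN).getShort(base));
        float[] out = new float[BLOCK_SIZE];
        for (int i = 0; i < BLOCK_SIZE; i++) {
            out[i] = scale * buf.get(base + 2 + i); // Java bytes are signed, matching the int8 quants
        }
        return out;
    }
}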

src/main/java/org/beehive/gpullama3/model/Configuration.java

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@

 public interface Configuration {

+    String quantization();
+
     /** Transformer embedding dimension */
     int dim();

src/main/java/org/beehive/gpullama3/model/llama/LlamaConfiguration.java

Lines changed: 8 additions & 1 deletion
@@ -3,7 +3,8 @@
 import org.beehive.gpullama3.model.Configuration;

 // @formatter:off
-public record LlamaConfiguration(int dim,
+public record LlamaConfiguration(String quantization,
+                                 int dim,
                                  int hiddenDim,
                                  int numberOfLayers,
                                  int numberOfHeads,
@@ -13,6 +14,11 @@ public record LlamaConfiguration(int dim,
                                  float rmsNormEps,
                                  float ropeTheta) implements Configuration {

+    @Override
+    public String quantization() {
+        return quantization;
+    }
+
     @Override
     public int numberOfHeadsKey() {
         throw new UnsupportedOperationException("Not supported for Llama.");
@@ -51,6 +57,7 @@ public LlamaConfiguration withContextLength(int newContextLength) {
             return this; // no change
         }
         return new LlamaConfiguration(
+                this.quantization,
                 this.dim,
                 this.hiddenDim,
                 this.numberOfLayers,

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 9 additions & 0 deletions
@@ -35,6 +35,15 @@ protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLen
         this.useTornadovm = useTornadovm;
     }

+    protected String getModelQuantization(Map<String, Object> metadata) {
+        int modelQuantizationAsInt = (int) metadata.get("general.file_type");
+        return switch (modelQuantizationAsInt) {
+            case 1 -> "FP16";
+            case 7 -> "Q8_0";
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + modelQuantizationAsInt + " (as int).");
+        };
+    }
+
     /**
      * Template method that defines the model loading workflow. Subclasses should not override this method.
      *
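The integer codes read from general.file_type follow llama.cpp's LLAMA_FTYPE enumeration, where 1 is MOSTLY_F16 and 7 is MOSTLY_Q8_0; any other file type (for example Q4_0, code 2) is rejected by the default branch rather than silently mishandled.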
