Merged
Changes from all commits (30 commits)
e6ce9e0
Refactor tensor loading and introduce support for Half-Float precisio…
mikepapadim Nov 26, 2025
db30dba
Replace `loadTornadoTensorAsFP32` with `loadTornadoTensor` across mod…
mikepapadim Nov 26, 2025
553015d
Add `modelType` to Configuration
mikepapadim Nov 26, 2025
da10c5c
Add `readModelType` integration for all model loaders
orionpapadakis Dec 4, 2025
579d6ea
Update `Q8_0` tensor creation to use `fromTornadoMemorySegment` method
orionpapadakis Dec 4, 2025
7adfd80
Add new `Q8_0TornadoTensor` constructor using `ByteArray` and `ByteAr…
orionpapadakis Dec 4, 2025
613cdd2
Add support for `Q8_0` weight type in `InferenceCore` embedding table…
orionpapadakis Dec 4, 2025
91f48b0
Change `embeddingX` type from `HalfFloatArray` to `TornadoNativeArray`
orionpapadakis Dec 4, 2025
786bdc2
Add `type` field to `LlamaConfiguration` constructor
orionpapadakis Dec 4, 2025
78d6a18
Add FP16 and Q8_0 support to `Activation` layer initialization
orionpapadakis Dec 4, 2025
7456d59
Add `convertQ8_0toFP32` kernel for dequantization in `TransformerComp…
orionpapadakis Dec 4, 2025
04db93d
Add `matrixVectorGenericQ8Byte` and `matrixVectorRowMajorOptimizedQ8_…
orionpapadakis Dec 4, 2025
dd8064e
Update FFN and attention layers to use Q8_0 byte-based kernels for ma…
orionpapadakis Dec 4, 2025
2316ca1
Update `LogitsQ8_0Layer` to use byte-based Q8_0 kernels
orionpapadakis Dec 4, 2025
56a960a
Remove deprecated methods for Q8_0 tensor loading and conversion to FP32
orionpapadakis Dec 4, 2025
9d0fb16
Add `matrixVectorGenericWithResidualQ8_0Byte` and `fusedFeedForwardWi…
orionpapadakis Dec 4, 2025
68729ee
Replace `getHalf` with `getHalfFloat` for Q8_0 block scale loading in…
orionpapadakis Dec 5, 2025
843e30c
Add FP16 and Q8_0 activation initialization methods in `State` class
orionpapadakis Dec 5, 2025
111dbdd
Use quantization-specific activation init in Llama models
orionpapadakis Dec 5, 2025
4e984fa
Use quantization-specific activation init in Qwen3 models
orionpapadakis Dec 5, 2025
ed5f882
Update Qwen3 FFN layers to use byte-based Q8_0 kernels
orionpapadakis Dec 5, 2025
4e30022
Use quantization-specific activation init in Qwen2 and Deepseek models
orionpapadakis Dec 5, 2025
c52bcaa
Update Qwen2 and Deepseek FFN layers to use byte-based Q8_0 kernels
orionpapadakis Dec 5, 2025
9562505
Use quantization-specific activation init in Phi3 models
orionpapadakis Dec 5, 2025
2c8cf24
Update Phi3 FFN layers to use byte-based Q8_0 kernels
orionpapadakis Dec 5, 2025
eccdce6
Rename `modelType` to `quantization` across configurations and update…
orionpapadakis Dec 5, 2025
4f13785
Cleanup unused memorySegment copy
orionpapadakis Dec 5, 2025
572f7b3
Cleanup and document `Q8_0TornadoTensor`
orionpapadakis Dec 5, 2025
a920424
Use Configuration.quantization() method in Activation
orionpapadakis Dec 5, 2025
6b66a59
[CI] Update Tornado dependencies to version 2.0.1-dev to run CI
orionpapadakis Dec 5, 2025
4 changes: 2 additions & 2 deletions pom.xml
@@ -54,12 +54,12 @@
<dependency>
<groupId>io.github.beehive-lab</groupId>
<artifactId>tornado-api</artifactId>
-        <version>2.0.0</version>
+        <version>2.0.1-dev</version>
</dependency>
<dependency>
<groupId>io.github.beehive-lab</groupId>
<artifactId>tornado-runtime</artifactId>
-        <version>2.0.0</version>
+        <version>2.0.1-dev</version>
</dependency>
</dependencies>

19 changes: 18 additions & 1 deletion src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -583,7 +583,24 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
final Configuration configuration = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();

-        MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+        switch (weights.getWeightType()) {
+            case F16 -> {
+                MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+                int bytes = Short.BYTES;
+                MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes);
+            }
+            case Q8_0 -> {
+                MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment();
+                int blockSize = 32;
+                int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
+                int blocksPerToken = (configuration.dim() + blockSize - 1) / blockSize; // Ceiling division
+                long bytesPerToken = (long) blocksPerToken * Q8_0_BLOCK_BYTES;
+
+                MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken, state.embeddingX.getSegment(), 0, bytesPerToken);
+            }
+            default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
+        }

return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
}
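The row copied here stays in its native Q8_0/FP16 layout; dequantization happens later on the GPU (the commit list adds a convertQ8_0toFP32 kernel for this). A minimal sketch of the computation such a kernel performs, assuming the 34-byte block layout above — the actual kernel in TransformerComputeKernelsLayered may differ in signature and would use TornadoVM's HalfFloat helpers rather than the JDK call:

```java
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

final class Q8_0DequantSketch {
    // Each Q8_0 block: 2-byte little-endian FP16 scale followed by 32 signed int8 quants.
    static void convertQ8_0toFP32(ByteArray q8, FloatArray out, int size) {
        for (@Parallel int i = 0; i < size; i++) {
            int block = i / 32;
            int idxInBlock = i % 32;
            int base = block * 34;
            // Reassemble the FP16 scale bits from two little-endian bytes.
            short scaleBits = (short) ((q8.get(base) & 0xFF) | ((q8.get(base + 1) & 0xFF) << 8));
            float scale = Float.float16ToFloat(scaleBits); // JDK 20+
            out.set(i, scale * q8.get(base + 2 + idxInBlock)); // value = scale * quant
        }
    }
}
```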
@@ -3,7 +3,9 @@
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -52,6 +54,11 @@ protected StateFields createStateFields(Configuration config) {
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());

+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
fields.wrapLogits = new FloatArray(config.vocabularySize());
fields.wrapQ = new FloatArray(config.dim());
fields.wrapK = new FloatArray(config.dim());
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -79,6 +80,11 @@ protected StateFields createStateFields(Configuration config) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);

// TornadoVM wrapper arrays for GPU acceleration
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
fields.wrapX = new FloatArray(dim);
fields.wrapXb = new FloatArray(dim);
fields.wrapXb2 = new FloatArray(dim);
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -40,6 +41,11 @@ protected StateFields createStateFields(Configuration configuration) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

// TornadoVM wrappers with Qwen2 dimensions
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -65,6 +66,13 @@ protected StateFields createStateFields(Configuration configuration) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

        // TornadoVM wrappers with Qwen3-specific sizes
+
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(config.dim());
+            case "Q8_0" -> fields.createActivationQ8_0(config.dim());
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
+
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
fields.wrapXb2 = new FloatArray(config.dim());
@@ -74,7 +82,6 @@ protected StateFields createStateFields(Configuration configuration) {
        fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads());
        fields.wrapK = new FloatArray(nEmbdKGqa);
        fields.wrapV = new FloatArray(nEmbdKGqa);
-
fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
fields.wrapValueCache.init(0.f);
19 changes: 17 additions & 2 deletions src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -2,8 +2,8 @@

import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.*;

/**
* Represents the base state structure used during LLM inference.
@@ -57,6 +57,7 @@ public abstract class State {
public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
public final IntArray positionHolder;

+    public TornadoNativeArray embeddingX;
[Review comment — Copilot AI, Dec 5, 2025]
The field name embeddingX is ambiguous. Since it stores quantized embeddings in their native format (FP16 or Q8_0), a clearer name would be quantizedEmbeddingX or nativeFormatEmbeddingX to distinguish it from the FP32 wrapX field.
// store inter
public int localSize;
public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
@@ -88,6 +89,7 @@ protected State(Configuration config, int batchsize) {
this.keyCache = fields.keyCache;
this.valueCache = fields.valueCache;

+        this.embeddingX = fields.embeddingX;
this.wrapX = fields.wrapX;
this.wrapXb = fields.wrapXb;
this.wrapXb2 = fields.wrapXb2;
@@ -121,6 +123,19 @@ protected static class StateFields {
public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
public IntArray positionHolder;
public FloatArray temp, tempFFN, tempLogits;
+        public TornadoNativeArray embeddingX;
+
+        public void createActivationFP16(int size) {
+            this.embeddingX = new HalfFloatArray(size);
+        }
+
+        public void createActivationQ8_0(int size) {
+            int blockSize = 32;
+            int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
+            int blocksNeeded = (size + blockSize - 1) / blockSize;
+            int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
+            this.embeddingX = new ByteArray(q8BytesNeeded);
+        }
}
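The two factory methods size the buffer differently: FP16 stores dim half-floats, while Q8_0 stores one 34-byte block per 32 elements. A quick sizing check under an assumed dim of 4096 (illustrative only, not a value from this PR):

```java
int dim = 4096;                        // hypothetical embedding dimension
int fp16Bytes = dim * Short.BYTES;     // HalfFloatArray payload: 8192 bytes
int blocks = (dim + 31) / 32;          // 128 Q8_0 blocks
int q8Bytes = blocks * 34;             // ByteArray payload: 4352 bytes (~53% of FP16)
```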

@Override
2 changes: 2 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/Configuration.java
@@ -2,6 +2,8 @@

public interface Configuration {

+    String quantization();
+
/** Transformer embedding dimension */
int dim();

@@ -3,7 +3,8 @@
import org.beehive.gpullama3.model.Configuration;

// @formatter:off
-public record LlamaConfiguration(int dim,
+public record LlamaConfiguration(String quantization,
+                                 int dim,
int hiddenDim,
int numberOfLayers,
int numberOfHeads,
@@ -13,6 +14,11 @@ public record LlamaConfiguration(int dim,
float rmsNormEps,
float ropeTheta) implements Configuration {

+    @Override
+    public String quantization() {
+        return quantization;
+    }
+
@Override
public int numberOfHeadsKey() {
throw new UnsupportedOperationException("Not supported for Llama.");
@@ -51,6 +57,7 @@ public LlamaConfiguration withContextLength(int newContextLength) {
return this; // no change
}
return new LlamaConfiguration(
+        this.quantization,
this.dim,
this.hiddenDim,
this.numberOfLayers,
@@ -35,6 +35,15 @@ protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLen
this.useTornadovm = useTornadovm;
}

+    protected String getModelQuantization(Map<String, Object> metadata) {
+        int modelQuantizationAsInt = (int) metadata.get("general.file_type");
+        return switch (modelQuantizationAsInt) {
+            case 1 -> "FP16";
+            case 7 -> "Q8_0";
[Review comment on lines +40 to +42 — Member]
what are these magic numbers 1 & 7?

+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + modelQuantizationAsInt + " (as int).");
+        };
+    }
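On the reviewer's question above: the values appear to come from llama.cpp's file-type enum, which GGUF carries in general.file_type — 1 is MOSTLY_F16 and 7 is MOSTLY_Q8_0. A sketch with named constants (the constant names are hypothetical, not part of this PR):

```java
// Hypothetical named constants mirroring llama.cpp's ftype values:
private static final int GGUF_FTYPE_MOSTLY_F16 = 1;
private static final int GGUF_FTYPE_MOSTLY_Q8_0 = 7;

protected String getModelQuantization(Map<String, Object> metadata) {
    int fileType = (int) metadata.get("general.file_type");
    return switch (fileType) {
        case GGUF_FTYPE_MOSTLY_F16 -> "FP16";
        case GGUF_FTYPE_MOSTLY_Q8_0 -> "Q8_0";
        default -> throw new UnsupportedOperationException(
                "Unsupported quantization format: " + fileType);
    };
}
```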

/**
* Template method that defines the model loading workflow. Subclasses should not override this method.
*
@@ -48,6 +48,7 @@ protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");

return new LlamaConfiguration(
+        getModelQuantization(metadata),
(int) metadata.get("llama.embedding_length"),
(int) metadata.get("llama.feed_forward_length"),
(int) metadata.get("llama.block_count"),
@@ -120,7 +121,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new LlamaTornadoWeights(
-        loadTornadoTensorAsFP32(tokenEmbeddings),
+        loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
@@ -50,6 +50,7 @@ protected MistralConfiguration createConfiguration(Map<String, Object> metadata)
int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");

return new MistralConfiguration(
+        getModelQuantization(metadata),
(int) metadata.get("llama.embedding_length"),
(int) metadata.get("llama.feed_forward_length"),
(int) metadata.get("llama.block_count"),
@@ -130,7 +131,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new LlamaTornadoWeights(
-        loadTornadoTensorAsFP32(tokenEmbeddings),
+        loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
@@ -127,7 +127,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
return switch (ggmlType) {
case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
-            case Q8_0 -> Q8_0TornadoTensor.createAsQ8_0(entry);
+            case Q8_0 -> Q8_0TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
};
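With this change, a Q8_0 tensor wraps its memory segment directly instead of being repacked at load time, and callers pick the view they need. A hedged usage sketch built on accessors that appear elsewhere in this PR (asFloatArray, asHalfFloatArray, asByteArray); the process(...) overloads are hypothetical placeholders:

```java
// Sketch: dispatch on tensor type; accessors are those used elsewhere in this PR,
// while the process(...) overloads are hypothetical.
TornadoTensor tensor = loadTornadoTensor(entry);
switch (tensor.type()) {
    case F32 -> process(tensor.asFloatArray());       // plain floats
    case F16 -> process(tensor.asHalfFloatArray());   // half floats, consumed as-is by kernels
    case Q8_0 -> process(tensor.asByteArray());       // raw 34-byte blocks, dequantized on GPU
    default -> throw new UnsupportedOperationException("Unsupported type: " + tensor.type());
}
```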
@@ -145,31 +145,6 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GG
return array;
}

-    /**
-     * Load a tensor and manually convert to FP32 (FloatArray).
-     * Used for embeddings that currently are treated as FP32.
-     * TODO: it is ultra-slow and should be removed
-     */
-    public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
-        TornadoTensor tensor = loadTornadoTensor(entry);
-        return switch (tensor.type()) {
-            case F32 -> tensor;
-            case F16 -> {
-                HalfFloatArray tensorHFA = tensor.asHalfFloatArray();
-                int numOfElements = tensorHFA.getSize();
-                FloatArray tensorFA = new FloatArray(numOfElements);
-                for (int i = 0; i < numOfElements; i++) {
-                    tensorFA.set(i, tensorHFA.get(i).getFloat32());
-                }
-                yield new FP32TornadoTensor(tensorFA);
-            }
-            case Q8_0 -> Q8_0TornadoTensor.createAsFP32(entry);
-            default -> {
-                throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type());
-            }
-        };
-    }
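This removal eliminates the per-element CPU conversion flagged by the TODO: weights now stay in block form and are consumed by the byte-based kernels named in the commit list (e.g. matrixVectorGenericQ8Byte). A plain-Java sketch of the computation such a kernel performs — the real kernel's signature, layout assumptions, and optimizations may differ:

```java
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

final class Q8_0MatVecSketch {
    // out[r] = sum over blocks b of scale(r,b) * sum_i q(r,b,i) * x[b*32+i]
    static void matVecQ8_0(ByteArray w, FloatArray x, FloatArray out, int dim, int rows) {
        int blocksPerRow = (dim + 31) / 32; // each weight row is blocksPerRow 34-byte blocks
        for (@Parallel int r = 0; r < rows; r++) {
            float acc = 0f;
            for (int b = 0; b < blocksPerRow; b++) {
                int base = (r * blocksPerRow + b) * 34;
                short bits = (short) ((w.get(base) & 0xFF) | ((w.get(base + 1) & 0xFF) << 8));
                float scale = Float.float16ToFloat(bits);
                float blockSum = 0f;
                for (int i = 0; i < 32 && b * 32 + i < dim; i++) {
                    blockSum += w.get(base + 2 + i) * x.get(b * 32 + i);
                }
                acc += scale * blockSum; // apply the block scale once per block
            }
            out.set(r, acc);
        }
    }
}
```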

// Helper methods

public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
@@ -188,14 +163,6 @@ public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<G
return array;
}

-    public static Q8_0TornadoTensor[] loadArrayAsQ8_0TornadoTensor(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
-        Q8_0TornadoTensor[] array = new Q8_0TornadoTensor[size];
-        for (int i = 0; i < size; i++) {
-            array[i] = Q8_0TornadoTensor.createAsQ8_0(getTensorEntry.apply(i));
-        }
-        return array;
-    }

public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
if (tensorEntry.ggmlType() == GGMLType.F32) {
FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
Expand Down
@@ -52,6 +52,7 @@ protected Phi3Configuration createConfiguration(Map<String, Object> metadata) {
final String modelPrefix = "phi3.";

var config = new Phi3Configuration(
+        getModelQuantization(metadata),
(int) metadata.get(modelPrefix + "embedding_length"), // dim
(int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
(int) metadata.get(modelPrefix + "block_count"), // n_layers
@@ -140,7 +141,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new Phi3TornadoWeights(
-        loadTornadoTensorAsFP32(tokenEmbeddings),
+        loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
@@ -52,6 +52,7 @@ protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
int vocabSize = vocabulary.size();

return new Qwen2Configuration(
+        getModelQuantization(metadata),
(int) metadata.get("qwen2.embedding_length"), // dim
(int) metadata.get("qwen2.feed_forward_length"), // hiddendim
(int) metadata.get("qwen2.block_count"), // numberOfLayers
@@ -137,7 +138,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new Qwen2TornadoWeights(
-        loadTornadoTensorAsFP32(tokenEmbeddings),
+        loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
@@ -52,6 +52,7 @@ protected Qwen3Configuration createConfiguration(Map<String, Object> metadata) {
int vocabSize = vocabulary.size();

return new Qwen3Configuration(
+        getModelQuantization(metadata),
(int) metadata.get("qwen3.embedding_length"),
(int) metadata.get("qwen3.feed_forward_length"),
(int) metadata.get("qwen3.block_count"),
@@ -137,7 +138,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
final int nl = config.numberOfLayers();

return new Qwen3TornadoWeights(
-        loadTornadoTensorAsFP32(tokenEmbeddings),
+        loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),