
Commit dd8064e

Update FFN and attention layers to use Q8_0 byte-based kernels for matrix-vector computations
1 parent 04db93d commit dd8064e

1 file changed: +21 additions, -16 deletions

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java

Lines changed: 21 additions & 16 deletions
@@ -61,40 +61,45 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.consumeFromDevice(state.wrapX);
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                 //Copy-in weights per layer for batched-layered layout
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), weights.wkLayered[layerIndex].getQuants(),
-                weights.wkLayered[layerIndex].getScales(), weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), weights.woLayered[layerIndex].getQuants(),
-                weights.woLayered[layerIndex].getScales(), weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(),
-                weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales());
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+                weights.wqLayered[layerIndex].asByteArray(),
+                weights.wkLayered[layerIndex].asByteArray(),
+                weights.wvLayered[layerIndex].asByteArray(),
+                weights.woLayered[layerIndex].asByteArray(),
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+                weights.w1Layered[layerIndex].asByteArray(),
+                weights.w2Layered[layerIndex].asByteArray(),
+                weights.w3Layered[layerIndex].asByteArray());
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
         unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
         if (shouldUseFinalNormalization()) {
             unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
                     config.dim(), config.rmsNormEps());
         }
         unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
-                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(),
-                        weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(),
-                        weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(),
-                        weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapQ,
+                        weights.wqLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapK,
+                        weights.wkLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, context, state.wrapXb, state.wrapV,
+                        weights.wvLayered[layerIndex].asByteArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
                 .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
                         layerIndex, config.contextLength());
         configureAttention(unifiedLayer, layerIndex);
-        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(),
-                weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapXb, state.wrapX,
+                weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
         if (shouldUseFinalNormalization()) {
             unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
                     config.dim(), config.rmsNormEps());
         }
         unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
-                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(),
-                        weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(),
+                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte, context, state.wrapXb, state.wrapHb,
+                        weights.w1Layered[layerIndex].asByteArray(), weights.w3Layered[layerIndex].asByteArray(), config.dim(), config.hiddenDim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(),
-                        weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context, state.wrapHb, state.wrapX,
+                        weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
         return unifiedLayer;
     }
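
As context for the change above: the Q8_0 byte-based kernels read the quantized weights and their scales from one packed byte buffer instead of the separate getQuants()/getScales() arrays used before. Below is a minimal plain-Java sketch of the matrix-vector arithmetic such a kernel performs. It assumes the GGUF Q8_0 block layout (per 32 weights, a 2-byte little-endian fp16 scale followed by 32 signed int8 quants, 34 bytes per block) and assumes asByteArray() lays those blocks out contiguously row by row; the actual device kernels such as matrixVectorGenericQ8Byte are parallel TornadoVM implementations and may organize the work differently.

// Illustrative only: a sequential reference for a Q8_0 byte-packed matrix-vector product.
// Layout assumption per block of 32 weights: [fp16 scale (2 bytes, little-endian)][32 x int8 quant].
final class Q8_0ByteMatVecSketch {
    static final int BLOCK_SIZE = 32;
    static final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE; // scale + quants

    // out[row] = dot(W[row, :], x) for a (rows x cols) Q8_0 weight matrix packed into `w`.
    static void matVec(byte[] w, float[] x, float[] out, int rows, int cols) {
        int blocksPerRow = cols / BLOCK_SIZE;
        for (int row = 0; row < rows; row++) {
            float acc = 0f;
            for (int b = 0; b < blocksPerRow; b++) {
                int base = (row * blocksPerRow + b) * BYTES_PER_BLOCK;
                // Decode the fp16 block scale (Float.float16ToFloat requires JDK 20+).
                short half = (short) ((w[base] & 0xFF) | (w[base + 1] << 8));
                float scale = Float.float16ToFloat(half);
                // Accumulate the int8 quants against the activations, then apply the scale once per block.
                float blockDot = 0f;
                for (int i = 0; i < BLOCK_SIZE; i++) {
                    blockDot += w[base + 2 + i] * x[b * BLOCK_SIZE + i];
                }
                acc += scale * blockDot;
            }
            out[row] = acc;
        }
    }
}

Packing scales and quants into a single buffer means each task in the graph above passes one weight argument and performs one copy-in per tensor, which is presumably the motivation for replacing the paired getQuants()/getScales() parameters.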