Skip to content

Commit 2316ca1

Browse files
Update LogitsQ8_0Layer to use byte-based Q8_0 kernels
1 parent dd8064e commit 2316ca1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,15 +58,15 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
5858
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
5959
TaskGraph logits = new TaskGraph("logits");
6060
logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
61-
.transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.getQuants(), weights.wclsByteArray.getScales(),
61+
.transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asByteArray(),
6262
weights.rms_final_weight_as_floatArray)
6363
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
6464
if (schedulerType == SchedulerType.NON_NVIDIA) {
6565
logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
6666
}
6767
logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
68-
.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
69-
context, state.wrapX, state.wrapLogits, weights.wclsByteArray.getQuants(), weights.wclsByteArray.getScales(), //
68+
.task("projection", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, //
69+
context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asByteArray(), //
7070
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) //
7171
.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
7272
return logits;

0 commit comments

Comments
 (0)