@@ -58,15 +58,15 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
5858 private TaskGraph setupLogitsTaskGraph (TornadoWeights weights , Configuration config ) {
5959 TaskGraph logits = new TaskGraph ("logits" );
6060 logits .consumeFromDevice (lastTaskGraphID , state .wrapX ).transferToDevice (DataTransferMode .EVERY_EXECUTION , state .tempLogits )
61- .transferToDevice (DataTransferMode .FIRST_EXECUTION , context , state .wrapLogits , weights .wclsByteArray .getQuants (), weights . wclsByteArray . getScales (),
61+ .transferToDevice (DataTransferMode .FIRST_EXECUTION , context , state .wrapLogits , weights .wclsByteArray .asByteArray (),
6262 weights .rms_final_weight_as_floatArray )
6363 .task ("reductionsOneBlockLogits" , TransformerComputeKernels ::reductionOneBlockWithLayer , context , state .tempLogits , state .wrapX , config .dim (), config .rmsNormEps (), state .localSize );
6464 if (schedulerType == SchedulerType .NON_NVIDIA ) {
6565 logits .task ("reductionFinalNormalizationLogits" , TransformerComputeKernelsLayered ::reductionFinalNormalization , context , state .tempLogits , config .dim (), config .rmsNormEps ());
6666 }
6767 logits .task ("mapContextLogits" , TransformerComputeKernels ::reductionOneBlock2WithLogits , context , state .wrapX , weights .rms_final_weight_as_floatArray .asFloatArray (), state .tempLogits )
68- .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric , //
69- context , state .wrapX , state .wrapLogits , weights .wclsByteArray .getQuants (), weights . wclsByteArray . getScales (), //
68+ .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGenericQ8Byte , //
69+ context , state .wrapX , state .wrapLogits , weights .wclsByteArray .asByteArray (), //
7070 config .dim (), config .vocabularySize (), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS ) //
7171 .transferToHost (DataTransferMode .EVERY_EXECUTION , state .wrapLogits );
7272 return logits ;
0 commit comments