@@ -61,40 +61,45 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
6161 unifiedLayer .consumeFromDevice (state .wrapX );
6262 unifiedLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
6363 //Copy-in weights per layer for batched-layered layout
64- weights .rms_att_weightLayered [layerIndex ].asFloatArray (), weights .wqLayered [layerIndex ].getQuants (), weights .wqLayered [layerIndex ].getScales (), weights .wkLayered [layerIndex ].getQuants (),
65- weights .wkLayered [layerIndex ].getScales (), weights .wvLayered [layerIndex ].getQuants (), weights .wvLayered [layerIndex ].getScales (), weights .woLayered [layerIndex ].getQuants (),
66- weights .woLayered [layerIndex ].getScales (), weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (), weights .w1Layered [layerIndex ].getQuants (), weights .w1Layered [layerIndex ].getScales (),
67- weights .w2Layered [layerIndex ].getQuants (), weights .w2Layered [layerIndex ].getScales (), weights .w3Layered [layerIndex ].getQuants (), weights .w3Layered [layerIndex ].getScales ());
64+ weights .rms_att_weightLayered [layerIndex ].asFloatArray (),
65+ weights .wqLayered [layerIndex ].asByteArray (),
66+ weights .wkLayered [layerIndex ].asByteArray (),
67+ weights .wvLayered [layerIndex ].asByteArray (),
68+ weights .woLayered [layerIndex ].asByteArray (),
69+ weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (),
70+ weights .w1Layered [layerIndex ].asByteArray (),
71+ weights .w2Layered [layerIndex ].asByteArray (),
72+ weights .w3Layered [layerIndex ].asByteArray ());
6873 unifiedLayer = configureLayerDataTransfers (unifiedLayer , layerIndex );
6974 unifiedLayer .task ("reductionsOneBlock" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , state .temp , state .wrapX , config .dim (), config .rmsNormEps (), state .localSize );
7075 if (shouldUseFinalNormalization ()) {
7176 unifiedLayer .task ("reductionFinalNormalization" , TransformerComputeKernelsLayered ::reductionFinalNormalization , context , state .temp ,
7277 config .dim (), config .rmsNormEps ());
7378 }
7479 unifiedLayer .task ("mapContext" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , state .wrapXb , state .wrapX , weights .rms_att_weightLayered [layerIndex ].asFloatArray (), state .temp )
75- .task ("qmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , state .wrapXb , state .wrapQ , weights . wqLayered [ layerIndex ]. getQuants () ,
76- weights .wqLayered [layerIndex ].getScales (), config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
77- .task ("kmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , state .wrapXb , state .wrapK , weights . wkLayered [ layerIndex ]. getQuants () ,
78- weights .wkLayered [layerIndex ].getScales (), config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
79- .task ("vmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , state .wrapXb , state .wrapV , weights . wvLayered [ layerIndex ]. getQuants () ,
80- weights .wvLayered [layerIndex ].getScales (), config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
80+ .task ("qmatmul" , TransformerComputeKernelsLayered ::matrixVectorGenericQ8Byte , context , state .wrapXb , state .wrapQ ,
81+ weights .wqLayered [layerIndex ].asByteArray (), config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
82+ .task ("kmatmul" , TransformerComputeKernelsLayered ::matrixVectorGenericQ8Byte , context , state .wrapXb , state .wrapK ,
83+ weights .wkLayered [layerIndex ].asByteArray (), config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
84+ .task ("vmatmul" , TransformerComputeKernelsLayered ::matrixVectorGenericQ8Byte , context , state .wrapXb , state .wrapV ,
85+ weights .wvLayered [layerIndex ].asByteArray (), config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
8186 .task ("rope" , TransformerComputeKernelsLayered ::ropeRotation , context , state .positionHolder , state .wrapQ , state .wrapK , config .kvDim (), config .headSize ())
8287 .task ("copyToCaches" , TransformerComputeKernelsLayered ::copyToCache , state .wrapKeyCache , state .wrapK , state .wrapValueCache , state .wrapV , state .positionHolder , config .kvDim (),
8388 layerIndex , config .contextLength ());
8489 configureAttention (unifiedLayer , layerIndex );
85- unifiedLayer .task ("matmul1" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context , state .wrapXb , state .wrapX , weights . woLayered [ layerIndex ]. getQuants () ,
86- weights .woLayered [layerIndex ].getScales (), config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
90+ unifiedLayer .task ("matmul1" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidualQ8_0Byte , context , state .wrapXb , state .wrapX ,
91+ weights .woLayered [layerIndex ].asByteArray (), config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
8792 .task ("reductionsOneBlockFFN" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , state .tempFFN , state .wrapX , config .dim (), config .rmsNormEps (), state .localSize );
8893 if (shouldUseFinalNormalization ()) {
8994 unifiedLayer .task ("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered ::reductionFinalNormalization , context , state .tempFFN ,
9095 config .dim (), config .rmsNormEps ());
9196 }
9297 unifiedLayer .task ("mapContextFFN" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , state .wrapXb , state .wrapX , weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (), state .tempFFN )
93- .task ("fused_ffn_w1_w3" , TransformerComputeKernelsLayered ::fusedFeedForwardWithSiLUAndGLUActivation , context , state .wrapXb , state .wrapHb , weights . w1Layered [ layerIndex ]. getQuants () ,
94- weights .w1Layered [layerIndex ].getScales (), weights .w3Layered [layerIndex ].getQuants (), weights . w3Layered [ layerIndex ]. getScales (), config .dim (), config .hiddenDim (),
98+ .task ("fused_ffn_w1_w3" , TransformerComputeKernelsLayered ::fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte , context , state .wrapXb , state .wrapHb ,
99+ weights .w1Layered [layerIndex ].asByteArray (), weights .w3Layered [layerIndex ].asByteArray (), config .dim (), config .hiddenDim (),
95100 LOCAL_WORK_GROUP_SIZE_ALLOC )
96- .task ("projectionTwo" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context , state .wrapHb , state .wrapX , weights . w2Layered [ layerIndex ]. getQuants () ,
97- weights .w2Layered [layerIndex ].getScales (), config .hiddenDim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC ).persistOnDevice (state .wrapX );
101+ .task ("projectionTwo" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidualQ8_0Byte , context , state .wrapHb , state .wrapX ,
102+ weights .w2Layered [layerIndex ].asByteArray (), config .hiddenDim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC ).persistOnDevice (state .wrapX );
98103 return unifiedLayer ;
99104 }
100105
0 commit comments