@@ -1128,6 +1128,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
         }
     }

+    public static void matrixVectorGenericWithResidualQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, ByteArray w, int n, int d, int localWorkGroupSize) {
+        // One row per workgroup (not per thread)
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+        int localSize = localWorkGroupSize;
+
+        // Early exit if this workgroup is beyond the output dimension
+        if (rowId >= d) {
+            return;
+        }
+
+        // All threads in the workgroup cooperatively reduce one row's dot product
+        float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localSize, x, w, n);
+
+        // Thread 0 in each workgroup adds the residual and writes the final result
+        if (localId == 0) {
+            float result = hb.get(rowId) + sum;
+            hb.set(rowId, result);
+        }
+    }
+
     public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, Int8Array w1_quants, HalfFloatArray w1_scales, Int8Array w3_quants,
             HalfFloatArray w3_scales, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
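Note: the cooperative per-row reduction is delegated to matrixVectorRowMajorOptimizedQ8_0Byte, which is defined outside this hunk. For reference, a minimal host-side sketch of the dot product it parallelizes, assuming the standard GGUF Q8_0 layout (per 32-value block: a 2-byte little-endian fp16 scale followed by 32 int8 quants); the helper name and layout here are assumptions, not the committed code:

    // Reference sketch (hypothetical): scalar Q8_0 row dot product.
    // Assumes n is a multiple of the 32-element block size.
    static float dotQ8_0Row(byte[] w, int rowOffsetBytes, float[] x, int n) {
        final int BLOCK = 32;
        final int BLOCK_BYTES = 2 + BLOCK; // fp16 scale + 32 int8 quants
        float sum = 0f;
        for (int block = 0; block < n / BLOCK; block++) {
            int base = rowOffsetBytes + block * BLOCK_BYTES;
            // Decode the per-block fp16 scale (Float.float16ToFloat requires Java 20+)
            short scaleBits = (short) ((w[base] & 0xFF) | (w[base + 1] << 8));
            float scale = Float.float16ToFloat(scaleBits);
            for (int j = 0; j < BLOCK; j++) {
                sum += scale * w[base + 2 + j] * x[block * BLOCK + j];
            }
        }
        return sum;
    }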
@@ -1149,6 +1169,29 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
         }
     }

+    public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb,
+            ByteArray w1,
+            ByteArray w3,
+            int n, int d, int localWorkGroupSize) {
+        // One row per workgroup (not per thread)
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        if (rowId >= d) {
+            return;
+        }
+
+        // Two cooperative row reductions: the gate projection (w1) and the up projection (w3)
+        float sum1 = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, w1, n);
+        float sum3 = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, w3, n);
+
+        // Thread 0 in each workgroup applies the SwiGLU gate and writes the final result
+        if (localId == 0) {
+            float silu = siluActivation(sum1); // silu(x) = x * sigmoid(x)
+            float result = silu * sum3;
+            hb.set(rowId, result);
+        }
+    }
+
     /**
      * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel.
      *
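The fused kernel above computes hb[row] = silu(x · w1[row]) * (x · w3[row]), the SwiGLU gating used in Llama-style feed-forward blocks. The siluActivation helper is defined elsewhere in the file; a minimal sketch of the function it computes, using java.lang.Math.exp as a stand-in for whatever math intrinsic the device kernel actually calls:

    // Sketch (hypothetical stand-in for the siluActivation helper):
    // silu(x) = x * sigmoid(x) = x / (1 + e^(-x))
    static float silu(float x) {
        return x / (1f + (float) Math.exp(-x));
    }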