@@ -60,7 +60,6 @@ public static void fusedRmsNormFFNGateUp(
6060 float scale = rmsScale .get (0 );
6161
6262 // Allocate shared memory for normalized input (reused for both W1 and W3)
63- float [] xNorm = context .allocateFloatLocalArray (localWorkGroupSize );
6463 float [] localSum = context .allocateFloatLocalArray (localWorkGroupSize );
6564
6665 int rowOffsetW1 = rowId * dim ;
@@ -1160,7 +1159,7 @@ public static void matrixVectorGeneric(
11601159 if (rowId >= dim0 ) {
11611160 return ;
11621161 }
1163- float sum = matrixVectorRowMajorOptimized (context , localSize , x , w , dim1 );
1162+ float sum = matrixVectorRowMajorOptimizedSingle (context , localSize , x , w , dim1 );
11641163
11651164 // Thread 0 in each workgroup writes the final result
11661165 if (localId == 0 ) {
@@ -1489,7 +1488,7 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo
14891488 return localSum [0 ];
14901489 }
14911490
1492- public static float matrixVectorRowMajorOptimized (KernelContext context , int localSize ,
1491+ public static float matrixVectorRowMajorOptimizedXX (KernelContext context , int localSize ,
14931492 HalfFloatArray x , HalfFloatArray w , int n ) {
14941493 int rowId = context .groupIdx ;
14951494 int localId = context .localIdx ;
@@ -1539,6 +1538,94 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
15391538 return localSum [0 ];
15401539 }
15411540
1541+ public static float matrixVectorRowMajorOptimizedSingle (KernelContext context , int localSize , HalfFloatArray x , HalfFloatArray w , int n ) {
1542+ int rowId = context .groupIdx ;
1543+ int localId = context .localIdx ;
1544+
1545+ // Allocate local memory for reduction
1546+ float [] localSum = context .allocateFloatLocalArray (localSize );
1547+
1548+ int rowOffset = rowId * n ;
1549+
1550+ HalfFloat partialSum = new HalfFloat (0f );
1551+ for (int j = localId ; j < n ; j += localSize ) {
1552+ int matrixIdx = rowOffset + j ;
1553+ HalfFloat mul = HalfFloat .mult (w .get (matrixIdx ), x .get (j ));
1554+ partialSum = HalfFloat .add (partialSum , mul );
1555+ }
1556+
1557+
1558+ // Store partial sum in local memory
1559+ localSum [localId ] = partialSum .getHalfFloatValue ();
1560+ context .localBarrier ();
1561+
1562+ // Parallel reduction within workgroup
1563+ for (int stride = localSize / 2 ; stride > 0 ; stride >>= 1 ) {
1564+ if (localId < stride ) {
1565+ localSum [localId ] += localSum [localId + stride ];
1566+ }
1567+ context .localBarrier ();
1568+ }
1569+
1570+ return localSum [0 ];
1571+ }
1572+
1573+ public static float matrixVectorRowMajorOptimized (KernelContext context , int localSize ,
1574+ HalfFloatArray x , HalfFloatArray w , int n ) {
1575+ int rowId = context .groupIdx ;
1576+ int localId = context .localIdx ;
1577+ float [] localSum = context .allocateFloatLocalArray (localSize );
1578+
1579+ int rowOffset = rowId * n ;
1580+
1581+ // Accumulate in HalfFloat to avoid conversions in inner loop
1582+ HalfFloat sum0 = new HalfFloat (0f );
1583+ HalfFloat sum1 = new HalfFloat (0f );
1584+ HalfFloat sum2 = new HalfFloat (0f );
1585+ HalfFloat sum3 = new HalfFloat (0f );
1586+
1587+ int stride = localSize ;
1588+ int stride2 = localSize << 1 ;
1589+ int stride3 = localSize * 3 ;
1590+ int stride4 = localSize << 2 ;
1591+
1592+ int j = localId ;
1593+ int limit = n - stride3 ;
1594+
1595+ for (; j < limit ; j += stride4 ) {
1596+ int base = rowOffset + j ;
1597+
1598+ // Stay in HalfFloat - no getFloat32() calls
1599+ HalfFloat x0 = x .get (j );
1600+ HalfFloat x1 = x .get (j + stride );
1601+ HalfFloat x2 = x .get (j + stride2 );
1602+ HalfFloat x3 = x .get (j + stride3 );
1603+
1604+ sum0 = HalfFloat .add (sum0 , HalfFloat .mult (w .get (base ), x0 ));
1605+ sum1 = HalfFloat .add (sum1 , HalfFloat .mult (w .get (base + stride ), x1 ));
1606+ sum2 = HalfFloat .add (sum2 , HalfFloat .mult (w .get (base + stride2 ), x2 ));
1607+ sum3 = HalfFloat .add (sum3 , HalfFloat .mult (w .get (base + stride3 ), x3 ));
1608+ }
1609+
1610+ // Cleanup loop
1611+ for (; j < n ; j += stride ) {
1612+ sum0 = HalfFloat .add (sum0 , HalfFloat .mult (w .get (rowOffset + j ), x .get (j )));
1613+ }
1614+
1615+ // Convert to float32 only at the end for reduction
1616+ localSum [localId ] = sum0 .getFloat32 () + sum1 .getFloat32 () + sum2 .getFloat32 () + sum3 .getFloat32 ();
1617+ context .localBarrier ();
1618+
1619+ for (int s = localSize >> 1 ; s > 0 ; s >>= 1 ) {
1620+ if (localId < s ) {
1621+ localSum [localId ] += localSum [localId + s ];
1622+ }
1623+ context .localBarrier ();
1624+ }
1625+
1626+ return localSum [0 ];
1627+ }
1628+
15421629 public static void fusedQKVMatmul (
15431630 KernelContext context ,
15441631 HalfFloatArray x , // input (read once!)
0 commit comments