Skip to content

Commit 1a98725

Browse files
committed
Refactor TransformerComputeKernelsLayered: replace matrixVectorRowMajorOptimized logic with matrixVectorRowMajorOptimizedSingle, remove unused floats, and streamline memory allocation and reduction.
1 parent 7c63dc4 commit 1a98725

File tree

1 file changed

+90
-3
lines changed

1 file changed

+90
-3
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ public static void fusedRmsNormFFNGateUp(
6060
float scale = rmsScale.get(0);
6161

6262
// Allocate shared memory for normalized input (reused for both W1 and W3)
63-
float[] xNorm = context.allocateFloatLocalArray(localWorkGroupSize);
6463
float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
6564

6665
int rowOffsetW1 = rowId * dim;
@@ -1160,7 +1159,7 @@ public static void matrixVectorGeneric(
11601159
if (rowId >= dim0) {
11611160
return;
11621161
}
1163-
float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1);
1162+
float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1);
11641163

11651164
// Thread 0 in each workgroup writes the final result
11661165
if (localId == 0) {
@@ -1489,7 +1488,7 @@ public static float matrixVectorRowMajorOptimizedF(KernelContext context, int lo
14891488
return localSum[0];
14901489
}
14911490

1492-
public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize,
1491+
public static float matrixVectorRowMajorOptimizedXX(KernelContext context, int localSize,
14931492
HalfFloatArray x, HalfFloatArray w, int n) {
14941493
int rowId = context.groupIdx;
14951494
int localId = context.localIdx;
@@ -1539,6 +1538,94 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
15391538
return localSum[0];
15401539
}
15411540

1541+
public static float matrixVectorRowMajorOptimizedSingle(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
1542+
int rowId = context.groupIdx;
1543+
int localId = context.localIdx;
1544+
1545+
// Allocate local memory for reduction
1546+
float[] localSum = context.allocateFloatLocalArray(localSize);
1547+
1548+
int rowOffset = rowId * n;
1549+
1550+
HalfFloat partialSum = new HalfFloat(0f);
1551+
for (int j = localId; j < n; j += localSize) {
1552+
int matrixIdx = rowOffset + j;
1553+
HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j));
1554+
partialSum = HalfFloat.add(partialSum, mul);
1555+
}
1556+
1557+
1558+
// Store partial sum in local memory
1559+
localSum[localId] = partialSum.getHalfFloatValue();
1560+
context.localBarrier();
1561+
1562+
// Parallel reduction within workgroup
1563+
for (int stride = localSize / 2; stride > 0; stride >>= 1) {
1564+
if (localId < stride) {
1565+
localSum[localId] += localSum[localId + stride];
1566+
}
1567+
context.localBarrier();
1568+
}
1569+
1570+
return localSum[0];
1571+
}
1572+
1573+
/**
 * Computes the dot product of row {@code groupIdx} of the row-major FP16 matrix
 * {@code w} with the FP16 vector {@code x}, accumulating in FP16 with 4-way
 * loop unrolling and reducing partial sums across the workgroup in local memory.
 *
 * <p>Each thread walks the row with stride {@code 4 * localSize}, keeping four
 * independent HalfFloat accumulators so half->float conversion happens only
 * once, just before the reduction. Assumes {@code localSize} is a power of two
 * (required by the halving reduction). NOTE(review): FP16 accumulation rounds
 * partial sums to half precision on every add — precision is traded for speed.
 *
 * @param context   TornadoVM kernel context (group/local ids, local memory)
 * @param localSize number of threads per workgroup (power of two)
 * @param x         input vector (FP16)
 * @param w         row-major weight matrix (FP16)
 * @param n         row length (number of columns)
 * @return the reduced row dot product; meaningful in thread 0 of the workgroup
 */
public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize,
        HalfFloatArray x, HalfFloatArray w, int n) {
    int rowId = context.groupIdx;
    int localId = context.localIdx;
    float[] localSum = context.allocateFloatLocalArray(localSize);

    int rowOffset = rowId * n;

    // Accumulate in HalfFloat to avoid conversions in inner loop
    HalfFloat sum0 = new HalfFloat(0f);
    HalfFloat sum1 = new HalfFloat(0f);
    HalfFloat sum2 = new HalfFloat(0f);
    HalfFloat sum3 = new HalfFloat(0f);

    // Per-thread strides for the four unrolled lanes
    int stride = localSize;
    int stride2 = localSize << 1;
    int stride3 = localSize * 3;
    int stride4 = localSize << 2;

    int j = localId;
    // Main loop reads up to j + stride3, so it is safe while j < n - stride3
    int limit = n - stride3;

    for (; j < limit; j += stride4) {
        int base = rowOffset + j;

        // Stay in HalfFloat - no getFloat32() calls
        HalfFloat x0 = x.get(j);
        HalfFloat x1 = x.get(j + stride);
        HalfFloat x2 = x.get(j + stride2);
        HalfFloat x3 = x.get(j + stride3);

        sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(base), x0));
        sum1 = HalfFloat.add(sum1, HalfFloat.mult(w.get(base + stride), x1));
        sum2 = HalfFloat.add(sum2, HalfFloat.mult(w.get(base + stride2), x2));
        sum3 = HalfFloat.add(sum3, HalfFloat.mult(w.get(base + stride3), x3));
    }

    // Cleanup loop: handles the tail elements not covered by the unrolled loop
    // (all accumulate into sum0; lane assignment does not matter for the total)
    for (; j < n; j += stride) {
        sum0 = HalfFloat.add(sum0, HalfFloat.mult(w.get(rowOffset + j), x.get(j)));
    }

    // Convert to float32 only at the end for reduction
    localSum[localId] = sum0.getFloat32() + sum1.getFloat32() + sum2.getFloat32() + sum3.getFloat32();
    context.localBarrier();

    // Power-of-two halving reduction across the workgroup
    for (int s = localSize >> 1; s > 0; s >>= 1) {
        if (localId < s) {
            localSum[localId] += localSum[localId + s];
        }
        context.localBarrier();
    }

    return localSum[0];
}
1628+
15421629
public static void fusedQKVMatmul(
15431630
KernelContext context,
15441631
HalfFloatArray x, // input (read once!)

0 commit comments

Comments
 (0)