Skip to content

Commit 9d0fb16

Browse files
Add matrixVectorGenericWithResidualQ8_0Byte and fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte kernels for byte-based Q8_0 computations
1 parent 56a960a commit 9d0fb16

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
11281128
}
11291129
}
11301130

1131+
public static void matrixVectorGenericWithResidualQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb, ByteArray w, int n, int d, int localWorkGroupSize) {
1132+
// One row per workgroup (not per thread)
1133+
int rowId = context.groupIdx;
1134+
int localId = context.localIdx;
1135+
int localSize = localWorkGroupSize;
1136+
1137+
// Early exit if this workgroup is beyond our output dimension
1138+
if (rowId >= d) {
1139+
return;
1140+
}
1141+
1142+
float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localSize, x, w, n);
1143+
1144+
// Thread 0 in each workgroup writes the final result
1145+
if (localId == 0) {
1146+
float result = hb.get(rowId) + sum;
1147+
hb.set(rowId, result);
1148+
}
1149+
}
1150+
11311151
public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, Int8Array w1_quants, HalfFloatArray w1_scales, Int8Array w3_quants,
11321152
HalfFloatArray w3_scales, int n, int d, int localWorkGroupSize) {
11331153
// One row per workgroup (not per thread)
@@ -1149,6 +1169,29 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
11491169
}
11501170
}
11511171

1172+
public static void fusedFeedForwardWithSiLUAndGLUActivationQ8_0Byte(KernelContext context, FloatArray x, FloatArray hb,
1173+
ByteArray w1,
1174+
ByteArray w3,
1175+
int n, int d, int localWorkGroupSize) {
1176+
// One row per workgroup (not per thread)
1177+
int rowId = context.groupIdx;
1178+
int localId = context.localIdx;
1179+
1180+
if (rowId >= d) {
1181+
return;
1182+
}
1183+
1184+
float sum1 = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, w1, n);
1185+
float sum3 = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, w3, n);
1186+
1187+
// Thread 0 in each workgroup writes the final result
1188+
if (localId == 0) {
1189+
float silu = siluActivation(sum1); // Using the new SiLU method
1190+
float result = silu * sum3;
1191+
hb.set(rowId, result);
1192+
}
1193+
}
1194+
11521195
/**
11531196
* Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel.
11541197
*

0 commit comments

Comments
 (0)