
Commit ed74652

Replace splitQKV kernel with fusedRmsNormQKVMatmulDirect, refactor Phi3 FP16 FFN layers to consolidate QKV projection tasks, and update worker grid/task configurations.
1 parent 6c1ac6f commit ed74652

2 files changed: 132 additions & 22 deletions


src/main/java/org/beehive/gpullama3/tornadovm/kernels/Phi3Kernels.java

Lines changed: 91 additions & 0 deletions
@@ -192,4 +192,95 @@ public static void ropeRotationWithCacheCopyPhi3(
             valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf));
         }
     }
+
+    /**
+     * Fused RMSNorm apply + QKV projection with direct output to separate Q, K, V buffers.
+     *
+     * <p>Eliminates the need for a separate splitQKV kernel by routing outputs
+     * directly based on row index:</p>
+     * <ul>
+     *     <li>Rows [0, dim): Q projection</li>
+     *     <li>Rows [dim, dim+kvDim): K projection</li>
+     *     <li>Rows [dim+kvDim, dim+2*kvDim): V projection</li>
+     * </ul>
+     *
+     * <p>Formula: output[row] = sum_j(Wqkv[row,j] * rmsWeight[j] * scale * x[j])</p>
+     *
+     * @param context            Kernel execution context
+     * @param x                  Input hidden state (FP32) [dim]
+     * @param q                  Output Q buffer (FP32) [dim]
+     * @param k                  Output K buffer (FP32) [kvDim]
+     * @param v                  Output V buffer (FP32) [kvDim]
+     * @param rmsWeights         RMS normalization weights (FP32) [dim]
+     * @param rmsScale           Precomputed RMS scale factor [1]
+     * @param wqkv               Combined QKV weight matrix (FP16) [opSize × dim]
+     * @param dim                Model dimension (Q output size)
+     * @param kvDim              KV dimension (K/V output size)
+     * @param localWorkGroupSize Local work group size for reduction
+     */
+    public static void fusedRmsNormQKVMatmulDirect(
+            KernelContext context,
+            FloatArray x,              // input (FP32)
+            FloatArray q,              // output Q (FP32)
+            FloatArray k,              // output K (FP32)
+            FloatArray v,              // output V (FP32)
+            FloatArray rmsWeights,     // RMS norm weights
+            FloatArray rmsScale,       // temp[0] = scale factor
+            HalfFloatArray wqkv,       // combined QKV weight matrix
+            int dim,                   // input dim and Q output dim
+            int kvDim,                 // K/V output dim
+            int localWorkGroupSize) {
+
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        // Total rows = dim (Q) + kvDim (K) + kvDim (V)
+        int totalRows = dim + 2 * kvDim;
+        if (rowId >= totalRows) {
+            return;
+        }
+
+        float scale = rmsScale.get(0);
+
+        // Allocate shared memory for reduction
+        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+        int rowOffset = rowId * dim;
+
+        // Each thread computes partial dot product with inline normalization
+        float partialSum = 0.0f;
+        for (int j = localId; j < dim; j += localWorkGroupSize) {
+            float normalized = rmsWeights.get(j) * scale * x.get(j);
+            partialSum += wqkv.get(rowOffset + j).getFloat32() * normalized;
+        }
+
+        localSum[localId] = partialSum;
+        context.localBarrier();
+
+        // Parallel reduction within workgroup
+        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+
+        // Thread 0 writes to appropriate output buffer
+        if (localId == 0) {
+            float result = localSum[0];
+
+            if (rowId < dim) {
+                // Q projection: rows [0, dim)
+                q.set(rowId, result);
+            } else if (rowId < dim + kvDim) {
+                // K projection: rows [dim, dim+kvDim)
+                int kIdx = rowId - dim;
+                k.set(kIdx, result);
+            } else {
+                // V projection: rows [dim+kvDim, dim+2*kvDim)
+                int vIdx = rowId - dim - kvDim;
+                v.set(vIdx, result);
+            }
+        }
+    }
 }
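
For reference, the row routing and the fused normalization above can be checked against a plain CPU loop. The sketch below is not part of the commit: it assumes an FP32 copy of the Wqkv matrix in row-major [(dim + 2*kvDim) × dim] layout and that rmsScale has already been produced by the preceding attn_rms_reduce task, but it computes the same output[row] = sum_j(Wqkv[row,j] * rmsWeight[j] * scale * x[j]) and routes each row to q, k, or v exactly as the kernel's thread-0 write does.

    // Hypothetical CPU reference for fusedRmsNormQKVMatmulDirect (illustration only, not in this commit).
    static void fusedRmsNormQkvReference(float[] x, float[] q, float[] k, float[] v,
                                         float[] rmsWeights, float rmsScale,
                                         float[] wqkv, int dim, int kvDim) {
        int totalRows = dim + 2 * kvDim;                 // Q rows + K rows + V rows
        for (int row = 0; row < totalRows; row++) {
            int rowOffset = row * dim;
            float sum = 0.0f;
            for (int j = 0; j < dim; j++) {
                // Inline RMSNorm apply, then dot product with row 'row' of Wqkv
                sum += wqkv[rowOffset + j] * (rmsWeights[j] * rmsScale * x[j]);
            }
            if (row < dim) {
                q[row] = sum;                            // rows [0, dim) -> Q
            } else if (row < dim + kvDim) {
                k[row - dim] = sum;                      // rows [dim, dim+kvDim) -> K
            } else {
                v[row - dim - kvDim] = sum;              // rows [dim+kvDim, dim+2*kvDim) -> V
            }
        }
    }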

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java

Lines changed: 41 additions & 22 deletions
@@ -79,13 +79,17 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
         // FFN down projection worker
         int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+        // Same worker as before - total rows = dim + 2*kvDim = opSize
 
+        // Remove: gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
         // Map workers to tasks for each layer (in task execution order)
         for (int i = 0; i < config.numberOfLayers(); i++) {
             // === Attention Block ===
             gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-            gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker);
-            gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
+            // gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker);
+
+            // gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);

@@ -257,29 +261,44 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
                 phi3Config.rmsNormEps(),  // epsilon
                 phi3State.localSize);     // local memory size
 
-        // Fused RMS Apply + QKV Projection (combined matrix)
-        unifiedLayer.task("attn_rms_qkv_matmul",
-                Phi3Kernels::fusedRmsNormMatmul,
+        // // Fused RMS Apply + QKV Projection (combined matrix)
+        // unifiedLayer.task("attn_rms_qkv_matmul",
+        //         Phi3Kernels::fusedRmsNormMatmul,
+        //         context,
+        //         phi3State.wrapX,   // input: raw hidden state (FP32)
+        //         phi3State.wrapQkv, // output: combined Q+K+V
+        //         weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
+        //         phi3State.temp,    // RMS scale factor from reduction
+        //         weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim]
+        //         phi3Config.dim(),  // input dimension
+        //         opSize,            // output dimension (Q + K + V)
+        //         LOCAL_WORK_GROUP_SIZE_ALLOC);
+        //
+        // // Split combined QKV into separate Q, K, V buffers
+        // unifiedLayer.task("splitQKV",
+        //         TransformerComputeKernelsLayered::splitQKV,
+        //         phi3State.wrapQkv,
+        //         phi3State.wrapQ,
+        //         phi3State.wrapK,
+        //         phi3State.wrapV,
+        //         phi3Config.dim(),
+        //         phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());
+
+        // AFTER: 1 task
+        unifiedLayer.task("attn_rms_qkv_projection",
+                Phi3Kernels::fusedRmsNormQKVMatmulDirect,
                 context,
-                phi3State.wrapX,   // input: raw hidden state (FP32)
-                phi3State.wrapQkv, // output: combined Q+K+V
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
-                phi3State.temp,    // RMS scale factor from reduction
-                weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim]
-                phi3Config.dim(),  // input dimension
-                opSize,            // output dimension (Q + K + V)
+                phi3State.wrapX,    // input
+                phi3State.wrapQ,    // output Q
+                phi3State.wrapK,    // output K
+                phi3State.wrapV,    // output V
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+                phi3State.temp,     // RMS scale
+                weights.wqkvLayered[layerIndex].asHalfFloatArray(),
+                phi3Config.dim(),   // dim
+                phi3Config.kvDim(), // kvDim
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        // Split combined QKV into separate Q, K, V buffers
-        unifiedLayer.task("splitQKV",
-                TransformerComputeKernelsLayered::splitQKV,
-                phi3State.wrapQkv,
-                phi3State.wrapQ,
-                phi3State.wrapK,
-                phi3State.wrapV,
-                phi3Config.dim(),
-                phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());
-
         // Fused Phi3 RoPE Rotation + KV Cache Write
         unifiedLayer.task("rope_and_kv_cache",
                 Phi3Kernels::ropeRotationWithCacheCopyPhi3,
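
The consolidated task keeps the launch shape of the old attn_rms_qkv_matmul task: one workgroup of LOCAL_WORK_GROUP_SIZE_ALLOC threads per output row, over dim + 2*kvDim rows (which equals opSize, as the inline comment in the first hunk notes), so the existing fusedQkvWorker is reused unchanged. A minimal sizing sketch, mirroring the WorkerGridFactory.genericWorker pattern used for ffnDownWorker above; the config.kvDim() accessor and variable names here are assumptions for illustration, not code from this commit:

    // Sketch only - how the reused fusedQkvWorker is typically dimensioned (assumed names).
    int fusedQkvRows = config.dim() + 2 * config.kvDim();             // Q + K + V output rows (= opSize)
    int fusedQkvGlobal = fusedQkvRows * LOCAL_WORK_GROUP_SIZE_ALLOC;  // one workgroup per output row
    WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);

    // Attached once per layer to the consolidated task, as in the first hunk:
    gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQkvWorker);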
