@@ -79,13 +79,17 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
7979 // FFN down projection worker
8080 int ffnDownGlobal = config .dim () * LOCAL_WORK_GROUP_SIZE_ALLOC ;
8181 WorkerGrid ffnDownWorker = WorkerGridFactory .genericWorker (ffnDownGlobal , LOCAL_WORK_GROUP_SIZE_ALLOC );
82+ // Same worker as before - total rows = dim + 2*kvDim = opSize
8283
84+ // Remove: gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
8385 // Map workers to tasks for each layer (in task execution order)
8486 for (int i = 0 ; i < config .numberOfLayers (); i ++) {
8587 // === Attention Block ===
8688 gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_reduce" , rmsNormWorker );
87- gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_qkv_matmul" , fusedQkvWorker );
88- gridScheduler .addWorkerGrid ("layer_" + i + ".splitQKV" , splitQKVWorker );
89+ // gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker);
90+ gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_qkv_projection" , fusedQkvWorker );
91+
92+ // gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
8993 gridScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
9094 gridScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
9195 gridScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
@@ -257,29 +261,44 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
257261 phi3Config .rmsNormEps (), // epsilon
258262 phi3State .localSize ); // local memory size
259263
260- // Fused RMS Apply + QKV Projection (combined matrix)
261- unifiedLayer .task ("attn_rms_qkv_matmul" ,
262- Phi3Kernels ::fusedRmsNormMatmul ,
264+ // // Fused RMS Apply + QKV Projection (combined matrix)
265+ // unifiedLayer.task("attn_rms_qkv_matmul",
266+ // Phi3Kernels::fusedRmsNormMatmul,
267+ // context,
268+ // phi3State.wrapX, // input: raw hidden state (FP32)
269+ // phi3State.wrapQkv, // output: combined Q+K+V
270+ // weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
271+ // phi3State.temp, // RMS scale factor from reduction
272+ // weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim]
273+ // phi3Config.dim(), // input dimension
274+ // opSize, // output dimension (Q + K + V)
275+ // LOCAL_WORK_GROUP_SIZE_ALLOC);
276+ //
277+ // // Split combined QKV into separate Q, K, V buffers
278+ // unifiedLayer.task("splitQKV",
279+ // TransformerComputeKernelsLayered::splitQKV,
280+ // phi3State.wrapQkv,
281+ // phi3State.wrapQ,
282+ // phi3State.wrapK,
283+ // phi3State.wrapV,
284+ // phi3Config.dim(),
285+ // phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());
286+
287+ // AFTER: 1 task
288+ unifiedLayer .task ("attn_rms_qkv_projection" ,
289+ Phi3Kernels ::fusedRmsNormQKVMatmulDirect ,
263290 context ,
264- phi3State .wrapX , // input: raw hidden state (FP32)
265- phi3State .wrapQkv , // output: combined Q+K+V
266- weights .rms_att_weightLayered [layerIndex ].asFloatArray (), // RMS weights
267- phi3State .temp , // RMS scale factor from reduction
268- weights .wqkvLayered [layerIndex ].asHalfFloatArray (), // Wqkv [opSize × dim]
269- phi3Config .dim (), // input dimension
270- opSize , // output dimension (Q + K + V)
291+ phi3State .wrapX , // input
292+ phi3State .wrapQ , // output Q
293+ phi3State .wrapK , // output K
294+ phi3State .wrapV , // output V
295+ weights .rms_att_weightLayered [layerIndex ].asFloatArray (),
296+ phi3State .temp , // RMS scale
297+ weights .wqkvLayered [layerIndex ].asHalfFloatArray (),
298+ phi3Config .dim (), // dim
299+ phi3Config .kvDim (), // kvDim
271300 LOCAL_WORK_GROUP_SIZE_ALLOC );
272301
273- // Split combined QKV into separate Q, K, V buffers
274- unifiedLayer .task ("splitQKV" ,
275- TransformerComputeKernelsLayered ::splitQKV ,
276- phi3State .wrapQkv ,
277- phi3State .wrapQ ,
278- phi3State .wrapK ,
279- phi3State .wrapV ,
280- phi3Config .dim (),
281- phi3Config .headSize () * phi3Config .numberOfKeyValueHeads ());
282-
283302 // Fused Phi3 RoPE Rotation + KV Cache Write
284303 unifiedLayer .task ("rope_and_kv_cache" ,
285304 Phi3Kernels ::ropeRotationWithCacheCopyPhi3 ,
0 commit comments