@@ -57,12 +57,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
             // === FFN Block ===
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker);
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
-
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
         }
         return tornadoForwardScheduler;
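As a side note on the registration pattern above, here is a minimal sketch of how a row-major worker grid such as `configDimRowMajorGlobalWorker` might be built and attached to the scheduler. The `LOCAL_WORK_GROUP_SIZE` and `DIM` values and the one-work-group-per-output-row layout are assumptions for illustration; only `GridScheduler.addWorkerGrid` and the `"layer_<i>.<task>"` naming appear in the diff itself.

```java
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

class WorkerGridSketch {

    // Assumed values, for illustration only; the real sizes come from the model Configuration.
    static final int LOCAL_WORK_GROUP_SIZE = 128; // threads per work-group (assumption)
    static final int DIM = 2048;                  // stand-in for config.dim() (assumption)

    // Mirrors the registration pattern in the hunk above: "<layer>.<task>" -> WorkerGrid.
    static GridScheduler registerDownProj(GridScheduler scheduler, int layerIndex) {
        // Assumed layout: one work-group of LOCAL_WORK_GROUP_SIZE threads per output row,
        // so the 1D global size is rows * local size.
        WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(DIM * LOCAL_WORK_GROUP_SIZE);
        configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE, 1, 1);

        scheduler.addWorkerGrid("layer_" + layerIndex + ".ffn_down_proj", configDimRowMajorGlobalWorker);
        return scheduler;
    }
}
```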
@@ -157,16 +152,11 @@ List<ImmutableTaskGraph> setupFFNLayered() {
  * └────────┬────────┘
  *          │
  *          ▼
- * ┌───────────────┐
- * │ ffn_rms_apply │──▶ wrapXb (normalized, FP32)
- * └───────┬───────┘
- *         │
- *         ▼
- * ┌─────────────┐
- * │ ffn_gate_up │──▶ wrapHb = SiLU(xb·W1) ⊙ (xb·W3)
- * └──────┬──────┘
- *        │
- *        ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘    (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+ *          │
+ *          ▼
  * ┌──────────────┐
  * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
  * └──────┬───────┘
@@ -176,16 +166,16 @@ List<ImmutableTaskGraph> setupFFNLayered() {
  *
  * ══════════════════════════════════════════════════════════════════════════════
  *
- * Task Count: 10 tasks (8 if NVIDIA, skipping rms_finalize steps)
+ * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps)
  *
  * Data Flow Summary:
  *   Input:  wrapX (FP32) - hidden state from previous layer
  *   Output: wrapX (FP32) - updated hidden state with residual connections
  *
  * Key Fusion Points:
- *   • qkv_projection:   Fused Q/K/V matmuls (3→1 kernel)
+ *   • qkv_projection:    Fused Q/K/V matmuls (3→1 kernel)
  *   • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
- *   • ffn_gate_up:       Fused W1/W3 matmuls + SiLU + GLU (3→1 kernel)
+ *   • rms_ffn_gate_up:   Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
  *
  */
 TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
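To make the fused task's semantics concrete, the sketch below spells out in plain scalar Java the math the diagram attributes to `ffn_rms_reduce` + `rms_ffn_gate_up` + `ffn_down_proj`: `wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)`, then `wrapX += W2·wrapHb`. This is a CPU reference for the formulas only, not the TornadoVM kernel; the array shapes and names are illustrative.

```java
// Scalar reference for the fused FFN block in the diagram above:
//   hb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
//   x += W2 · hb                      (residual connection)
final class FusedFfnReference {

    static float silu(float v) {
        return v / (1f + (float) Math.exp(-v)); // SiLU(v) = v * sigmoid(v)
    }

    // x: [dim], rmsWeight: [dim], w1/w3: [hiddenDim][dim], w2: [dim][hiddenDim]
    static void ffnBlock(float[] x, float[] rmsWeight,
                         float[][] w1, float[][] w2, float[][] w3, float eps) {
        int dim = x.length;
        int hiddenDim = w1.length;

        // 1) RMS norm of the residual stream (ffn_rms_reduce + the fused "apply" step).
        float ss = 0f;
        for (float v : x) ss += v * v;
        float scale = (float) (1.0 / Math.sqrt(ss / dim + eps));
        float[] xb = new float[dim];
        for (int i = 0; i < dim; i++) xb[i] = rmsWeight[i] * (scale * x[i]);

        // 2) Gate (W1) and up (W3) projections, SiLU on the gate, element-wise GLU product.
        float[] hb = new float[hiddenDim];
        for (int r = 0; r < hiddenDim; r++) {
            float gate = 0f, up = 0f;
            for (int c = 0; c < dim; c++) {
                gate += w1[r][c] * xb[c];
                up   += w3[r][c] * xb[c];
            }
            hb[r] = silu(gate) * up;
        }

        // 3) Down projection (W2) added back into x (ffn_down_proj, residual add).
        for (int r = 0; r < dim; r++) {
            float acc = 0f;
            for (int c = 0; c < hiddenDim; c++) acc += w2[r][c] * hb[c];
            x[r] += acc;
        }
    }
}
```

Folding the RMS "apply" step into the gate/up kernel drops one task per layer (10 → 9), presumably so the normalized activations never have to round-trip through `wrapXb` in global memory between separate kernels.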
@@ -275,19 +265,6 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
                     context, state.tempFFN, config.dim(), config.rmsNormEps());
 
 
-        // unifiedLayer.task("ffn_rms_apply",
-        //         TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-        //         context, state.wrapXb, state.wrapX,
-        //         weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN);
-        //
-        // // Gate + Up projection with SiLU activation (W1, W3)
-        // unifiedLayer.task("ffn_gate_up",
-        //         TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
-        //         context, state.wrapXb, state.wrapHb,
-        //         weights.w1Layered[layerIndex].asHalfFloatArray(),
-        //         weights.w3Layered[layerIndex].asHalfFloatArray(),
-        //         config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
-
         unifiedLayer.task("rms_ffn_gate_up",
                 TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
                 context,