Commit 577b6b1
Increase BLOCK_SIZE_C to 16 in the Transformer kernel, and update the FP16 FFN task graphs: remove the deprecated tasks and consolidate RMS normalization and the FFN gate/up operations into a single rms_ffn_gate_up task.
1 parent a1c94fb commit 577b6b1

File tree: 2 files changed, +9 −32 lines


src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 1 addition & 1 deletion
@@ -485,7 +485,7 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
         int pos = positionHolder.get(0);
         int loff = layer * contextLength * kvDim;
         int kvHeadIdx = h / kvMul;
-        int BLOCK_SIZE_C = 8;
+        int BLOCK_SIZE_C = 16;
 
         // Allocate shared memory for tiled computation
         float[] q_shared = context.allocateFloatLocalArray(headSize);
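For context, BLOCK_SIZE_C is presumably the number of cached key/value positions each tile of the flash-attention loop covers, so raising it from 8 to 16 halves the number of tiles per query, typically at the cost of more local memory per tile. The sketch below is illustrative only, not the TornadoVM kernel: it shows, on plain Java arrays with hypothetical names, how a column block size partitions the cached positions and how an online softmax keeps earlier tiles numerically consistent.

// Illustrative sketch (not the project's kernel): tiled attention over cached
// positions in blocks of blockSizeC, with online-softmax rescaling. Only a
// single output element is accumulated, to keep the example short.
final class BlockTilingSketch {

    static float attendOneQuery(float[] q, float[][] keys, float[][] values,
                                int pos, int headSize, int blockSizeC) {
        float maxScore = Float.NEGATIVE_INFINITY;
        float sumExp = 0f;
        float acc = 0f;

        // Walk positions [0, pos] one block of blockSizeC columns at a time.
        for (int tileStart = 0; tileStart <= pos; tileStart += blockSizeC) {
            int tileEnd = Math.min(tileStart + blockSizeC, pos + 1);
            for (int t = tileStart; t < tileEnd; t++) {
                float score = 0f;
                for (int d = 0; d < headSize; d++) {
                    score += q[d] * keys[t][d];
                }
                score /= (float) Math.sqrt(headSize);

                // Online softmax: rescale previous accumulators when a new max appears.
                float newMax = Math.max(maxScore, score);
                float scale = (float) Math.exp(maxScore - newMax);
                float w = (float) Math.exp(score - newMax);
                sumExp = sumExp * scale + w;
                acc = acc * scale + w * values[t][0];
                maxScore = newMax;
            }
        }
        return acc / sumExp;
    }
}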

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 8 additions & 31 deletions
@@ -57,12 +57,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
             // === FFN Block ===
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_apply", rmsNormWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_gate_up", configHiddenDimRowMajorWorker);
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
-
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
         }
         return tornadoForwardScheduler;
@@ -157,16 +152,11 @@ List<ImmutableTaskGraph> setupFFNLayered() {
     * └────────┬────────┘
     *          │
     *          ▼
-    * ┌───────────────┐
-    * │ ffn_rms_apply │──▶ wrapXb (normalized, FP32)
-    * └───────┬───────┘
-    *         │
-    *         ▼
-    * ┌─────────────┐
-    * │ ffn_gate_up │──▶ wrapHb = SiLU(xb·W1) ⊙ (xb·W3)
-    * └──────┬──────┘
-    *        │
-    *        ▼
+    * ┌─────────────────┐
+    * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+    * └────────┬────────┘    (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+    *          │
+    *          ▼
     * ┌──────────────┐
     * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
     * └──────┬───────┘
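The diagram's formula, wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3), is what the consolidated task now computes in a single kernel launch instead of a separate RMS-apply pass followed by the gate/up matmuls. A minimal plain-Java sketch of that arithmetic (FP32 arrays, hypothetical names, no TornadoVM), just to spell the math out:

// Illustrative only: the arithmetic behind rms_ffn_gate_up on plain FP32 arrays.
// The real task runs as a TornadoVM kernel over half-float weights.
final class RmsFfnGateUpSketch {

    static float[] rmsFfnGateUp(float[] x, float[] rmsWeight, float[][] w1, float[][] w3,
                                int dim, int hiddenDim, float rmsNormEps) {
        // RMS normalization factor (the "RMS apply" that used to be its own task).
        float ss = 0f;
        for (int j = 0; j < dim; j++) {
            ss += x[j] * x[j];
        }
        float inv = (float) (1.0 / Math.sqrt(ss / dim + rmsNormEps));

        float[] hb = new float[hiddenDim];
        for (int i = 0; i < hiddenDim; i++) {
            float gate = 0f; // row i of W1 (gate projection)
            float up = 0f;   // row i of W3 (up projection)
            for (int j = 0; j < dim; j++) {
                float xn = x[j] * inv * rmsWeight[j]; // normalized, scaled input element
                gate += w1[i][j] * xn;
                up += w3[i][j] * xn;
            }
            float silu = gate / (1f + (float) Math.exp(-gate)); // SiLU(g) = g * sigmoid(g)
            hb[i] = silu * up; // GLU: elementwise product with the up projection
        }
        return hb;
    }
}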
@@ -176,16 +166,16 @@ List<ImmutableTaskGraph> setupFFNLayered() {
     *
     * ══════════════════════════════════════════════════════════════════════════════
     *
-    * Task Count: 10 tasks (8 if NVIDIA, skipping rms_finalize steps)
+    * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps)
     *
     * Data Flow Summary:
     *   Input: wrapX (FP32) - hidden state from previous layer
     *   Output: wrapX (FP32) - updated hidden state with residual connections
     *
     * Key Fusion Points:
-    *   • qkv_projection: Fused Q/K/V matmuls (3→1 kernel)
+    *   • qkv_projection: Fused Q/K/V matmuls (3→1 kernel)
     *   • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
-    *   • ffn_gate_up: Fused W1/W3 matmuls + SiLU + GLU (3→1 kernel)
+    *   • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
     */
    TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
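The other fusion called out above, rope_and_kv_cache, is unchanged by this commit; for readers unfamiliar with it, the hedged sketch below illustrates the general idea of rotating K with RoPE and writing it into the KV cache in one pass (plain Java, hypothetical names and layout, not the project's actual kernel).

// Illustrative only: RoPE rotation fused with the cache write, assuming a
// [contextLength x kvDim] key cache laid out row-per-position.
final class RopeCacheSketch {

    static void ropeAndCacheWrite(float[] k, float[] keyCache, int pos, int kvDim,
                                  int headSize, float ropeTheta) {
        int cacheOffset = pos * kvDim; // row for this position
        for (int i = 0; i < kvDim; i += 2) {
            int dimInHead = i % headSize;
            double freq = 1.0 / Math.pow(ropeTheta, (double) dimInHead / headSize);
            double angle = pos * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float k0 = k[i];
            float k1 = k[i + 1];
            // Rotate the (even, odd) pair and store straight into the cache,
            // skipping the intermediate buffer a separate copy kernel would need.
            keyCache[cacheOffset + i] = k0 * cos - k1 * sin;
            keyCache[cacheOffset + i + 1] = k0 * sin + k1 * cos;
        }
    }
}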
@@ -275,19 +265,6 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
                     context, state.tempFFN, config.dim(), config.rmsNormEps());
         }
 
-        // unifiedLayer.task("ffn_rms_apply",
-        //         TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-        //         context, state.wrapXb, state.wrapX,
-        //         weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN);
-        //
-        // // Gate + Up projection with SiLU activation (W1, W3)
-        // unifiedLayer.task("ffn_gate_up",
-        //         TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
-        //         context, state.wrapXb, state.wrapHb,
-        //         weights.w1Layered[layerIndex].asHalfFloatArray(),
-        //         weights.w3Layered[layerIndex].asHalfFloatArray(),
-        //         config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
-
         unifiedLayer.task("rms_ffn_gate_up",
                 TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
                 context,
