Commit 6c1ac6f

Add Phi3-specific fused kernels for RMSNorm+QKV and RMSNorm+Gate/Up, update Phi3 FP16 FFN layers with optimized worker grid configurations, fused workflows for attention and FFN blocks, and detailed task flow documentation.
1 parent 428e5cc commit 6c1ac6f

3 files changed: +470 -195 lines changed
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
package org.beehive.gpullama3.tornadovm.kernels;

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

/**
 * Phi3Kernels: Optimized GPU kernels for the Phi3 model family.
 *
 * <p>Key differences from Qwen/Llama kernels:</p>
 * <ul>
 *   <li>Generic fused RMS + matmul (single output matrix)</li>
 *   <li>Phi3 RoPE with headSize/2 offset pattern</li>
 *   <li>Combined gate/up structure support</li>
 * </ul>
 */
public class Phi3Kernels {

    /**
     * Fused RMSNorm apply + single matrix-vector multiplication.
     *
     * <p>Combines RMS normalization application with a generic matmul in one kernel,
     * reducing memory bandwidth by avoiding intermediate storage.</p>
     *
     * <p>Formula: output[row] = sum_j(W[row,j] * rmsWeight[j] * scale * x[j])</p>
     *
     * <p>Use cases:</p>
     * <ul>
     *   <li>Phi3 combined QKV projection (output = wqkv · RMSNorm(x))</li>
     *   <li>Phi3 combined gate/up projection (output = wUp · RMSNorm(x))</li>
     *   <li>Any single-matrix projection after RMSNorm</li>
     * </ul>
     *
     * @param context            Kernel execution context
     * @param x                  Input hidden state (FP32) [dim]
     * @param output             Output buffer (FP32) [outputDim]
     * @param rmsWeights         RMS normalization weights (FP32) [dim]
     * @param rmsScale           Precomputed RMS scale factor [1] (from reduction kernel)
     * @param w                  Weight matrix (FP16) [outputDim × dim]
     * @param inputDim           Input dimension (dim)
     * @param outputDim          Output dimension
     * @param localWorkGroupSize Local work group size for reduction
     */
    public static void fusedRmsNormMatmul(
            KernelContext context,
            FloatArray x,              // input (FP32)
            FloatArray output,         // output (FP32)
            FloatArray rmsWeights,     // RMS norm weights
            FloatArray rmsScale,       // temp[0] = scale factor
            HalfFloatArray w,          // weight matrix
            int inputDim,              // input dimension
            int outputDim,             // output dimension
            int localWorkGroupSize) {

        int rowId = context.groupIdx;
        int localId = context.localIdx;

        if (rowId >= outputDim) {
            return;
        }

        float scale = rmsScale.get(0);

        // Allocate shared memory for reduction
        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);

        int rowOffset = rowId * inputDim;

        // Each thread computes partial dot product with inline normalization
        float partialSum = 0.0f;
        for (int j = localId; j < inputDim; j += localWorkGroupSize) {
            float normalized = rmsWeights.get(j) * scale * x.get(j);
            partialSum += w.get(rowOffset + j).getFloat32() * normalized;
        }

        localSum[localId] = partialSum;
        context.localBarrier();

        // Parallel reduction within workgroup
        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }

        // Thread 0 writes final result
        if (localId == 0) {
            output.set(rowId, localSum[0]);
        }
    }

    /**
     * Phi3 RoPE rotation with fused KV cache copy.
     *
     * <p>Phi3 uses a different RoPE pattern than Llama/Qwen:</p>
     * <ul>
     *   <li>Pairs elements with offset headSize/2 (not adjacent pairs)</li>
     *   <li>Each thread processes one dimension pair across all heads</li>
     *   <li>Iterates over heads internally</li>
     * </ul>
     *
     * <p>This fused kernel combines:</p>
     * <ul>
     *   <li>Phi3-style RoPE rotation for Q and K</li>
     *   <li>Direct cache write for rotated K</li>
     *   <li>Direct cache copy for V (no rotation)</li>
     * </ul>
     *
     * @param context        Kernel execution context
     * @param positionHolder Current position in sequence [1]
     * @param sq             Query vectors (in/out, rotated) [dim]
     * @param sk             Key vectors (in/out, rotated) [kvDim]
     * @param sv             Value vectors (in only) [kvDim]
     * @param keyCache       Key cache (out) [layers × contextLength × kvDim]
     * @param valueCache     Value cache (out) [layers × contextLength × kvDim]
     * @param nHeadKv        Number of KV heads
     * @param headSize       Dimension per head
     * @param kvDim          Total KV dimension (nHeadKv × headSize)
     * @param layer          Current layer index
     * @param contextLength  Maximum sequence length
     */
    public static void ropeRotationWithCacheCopyPhi3(
            KernelContext context,
            IntArray positionHolder,
            FloatArray sq,             // Q vector (in/out)
            FloatArray sk,             // K vector (in/out)
            FloatArray sv,             // V vector (in only)
            FloatArray keyCache,       // Key cache (out)
            FloatArray valueCache,     // Value cache (out)
            int nHeadKv,
            int headSize,
            int kvDim,
            int layer,
            int contextLength) {

        int idx = context.globalIdx;
        int dimHalf = headSize / 2;

        // Each thread processes one dimension pair
        if (idx >= dimHalf) {
            return;
        }

        int pos = positionHolder.get(0);
        int cacheOffset = layer * contextLength * kvDim + pos * kvDim;

        // Calculate frequency for this dimension
        float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) headSize);
        float val = pos * freq;
        float fcr = TornadoMath.cos(val);
        float fci = TornadoMath.sin(val);

        // Process Q: all heads (dim = nHeads × headSize)
        int totalDimQ = sq.getSize();
        for (int base = 0; base < totalDimQ; base += headSize) {
            if (base + idx >= totalDimQ || base + idx + dimHalf >= totalDimQ) {
                break;
            }

            // Rotate Q with offset pattern
            float v0 = sq.get(base + idx);
            float v1 = sq.get(base + idx + dimHalf);
            sq.set(base + idx, v0 * fcr - v1 * fci);
            sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr);
        }

        // Process K: only kvDim elements, with cache write
        for (int base = 0; base < kvDim; base += headSize) {
            if (base + idx >= kvDim || base + idx + dimHalf >= kvDim) {
                break;
            }

            // Rotate K with offset pattern
            float k0 = sk.get(base + idx);
            float k1 = sk.get(base + idx + dimHalf);
            float rotated0 = k0 * fcr - k1 * fci;
            float rotated1 = k0 * fci + k1 * fcr;

            // Write rotated K back
            sk.set(base + idx, rotated0);
            sk.set(base + idx + dimHalf, rotated1);

            // Fused cache write for K
            keyCache.set(cacheOffset + base + idx, rotated0);
            keyCache.set(cacheOffset + base + idx + dimHalf, rotated1);

            // Fused cache copy for V (no rotation needed)
            valueCache.set(cacheOffset + base + idx, sv.get(base + idx));
            valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf));
        }
    }
}
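For context, a minimal sketch of how fusedRmsNormMatmul might be launched from the host with TornadoVM's KernelContext API. This is not part of the commit: the task-graph/task names ("phi3", "qkv"), the localSize of 128, and the assumption that rmsScale was already filled by a separate RMS reduction kernel are illustrative. Because the kernel uses groupIdx as the output row, the global work size is outputDim × localSize with a local work size of localSize.

import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class Phi3FusedQkvLaunchSketch {

    // Hypothetical host-side wiring for the fused RMSNorm + QKV projection.
    public static FloatArray runFusedRmsNormQkv(FloatArray x, FloatArray rmsWeights, FloatArray rmsScale,
                                                HalfFloatArray wqkv, int dim, int opSize) {
        int localSize = 128;                          // threads per workgroup (one workgroup per output row)
        FloatArray qkv = new FloatArray(opSize);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("phi3")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, rmsWeights, wqkv)
                // rmsScale is assumed to come from a prior RMS reduction kernel
                .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsScale)
                .task("qkv", Phi3Kernels::fusedRmsNormMatmul,
                        context, x, qkv, rmsWeights, rmsScale, wqkv, dim, opSize, localSize)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, qkv);

        // groupIdx selects the output row, so launch opSize workgroups of localSize threads each
        WorkerGrid worker = new WorkerGrid1D(opSize * localSize);
        worker.setLocalWork(localSize, 1, 1);
        GridScheduler scheduler = new GridScheduler("phi3.qkv", worker);

        ImmutableTaskGraph itg = taskGraph.snapshot();
        TornadoExecutionPlan plan = new TornadoExecutionPlan(itg);
        plan.withGridScheduler(scheduler).execute();
        return qkv;
    }
}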

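As a plain-Java illustration of the headSize/2 offset pairing described in the ropeRotationWithCacheCopyPhi3 javadoc (element i rotates with element i + headSize/2, rather than the adjacent (2i, 2i+1) pairing used by the Llama/Qwen kernels), the sequential reference below can be used to sanity-check the GPU output. The method name is hypothetical and KV-cache handling is omitted; only the rotation itself is shown.

/**
 * Sequential reference for the Phi3 RoPE pairing (illustrative, not part of the commit):
 * within each head, element i (0 <= i < headSize/2) rotates with element i + headSize/2.
 */
static void phi3RopeReference(float[] vec, int headSize, int pos) {
    int dimHalf = headSize / 2;
    for (int base = 0; base < vec.length; base += headSize) {   // iterate over heads
        for (int i = 0; i < dimHalf; i++) {
            float freq = (float) (1.0 / Math.pow(10000.0, (2.0 * i) / headSize));
            float fcr = (float) Math.cos(pos * freq);
            float fci = (float) Math.sin(pos * freq);
            float v0 = vec[base + i];
            float v1 = vec[base + i + dimHalf];
            vec[base + i] = v0 * fcr - v1 * fci;                // real component
            vec[base + i + dimHalf] = v0 * fci + v1 * fcr;      // imaginary component
        }
    }
}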