package org.beehive.gpullama3.tornadovm.kernels;

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

/**
 * Phi3Kernels: optimized GPU kernels for the Phi3 model family.
 *
 * <p>Key differences from the Qwen/Llama kernels:</p>
 * <ul>
 *   <li>Generic fused RMSNorm + matmul (single output matrix)</li>
 *   <li>Phi3 RoPE with the headSize/2 offset pattern</li>
 *   <li>Support for the combined gate/up weight structure</li>
 * </ul>
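 *
 * <p>Illustrative host-side wiring (a minimal sketch under assumed TornadoVM 1.x
 * APIs; the graph/task names, buffer variables, and the workgroup size of 256
 * are hypothetical placeholders, not taken from the model code):</p>
 * <pre>{@code
 * KernelContext context = new KernelContext();
 * int local = 256; // assumed workgroup size
 *
 * // One workgroup per output row of the fused RMSNorm + matmul kernel.
 * WorkerGrid1D worker = new WorkerGrid1D(outputDim * local);
 * worker.setLocalWork(local, 1, 1);
 * GridScheduler scheduler = new GridScheduler("phi3.qkv", worker);
 *
 * TaskGraph graph = new TaskGraph("phi3")
 *         .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsScale)
 *         .transferToDevice(DataTransferMode.FIRST_EXECUTION, rmsWeights, wqkv)
 *         .task("qkv", Phi3Kernels::fusedRmsNormMatmul,
 *                 context, x, qkv, rmsWeights, rmsScale, wqkv, dim, outputDim, local)
 *         .transferToHost(DataTransferMode.EVERY_EXECUTION, qkv);
 *
 * new TornadoExecutionPlan(graph.snapshot())
 *         .withGridScheduler(scheduler)
 *         .execute();
 * }</pre>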
 */
public class Phi3Kernels {

    /**
     * Fused RMSNorm apply + single matrix-vector multiplication.
     *
     * <p>Combines RMS normalization application with a generic matmul in one kernel,
     * reducing memory bandwidth by avoiding intermediate storage.</p>
     *
     * <p>Formula: output[row] = sum_j(W[row,j] * rmsWeight[j] * scale * x[j])</p>
     *
     * <p>Use cases:</p>
     * <ul>
     *   <li>Phi3 combined QKV projection (output = wqkv · RMSNorm(x))</li>
     *   <li>Phi3 combined gate/up projection (output = wUp · RMSNorm(x))</li>
     *   <li>Any single-matrix projection after RMSNorm</li>
     * </ul>
     *
     * @param context            Kernel execution context
     * @param x                  Input hidden state (FP32) [dim]
     * @param output             Output buffer (FP32) [outputDim]
     * @param rmsWeights         RMS normalization weights (FP32) [dim]
     * @param rmsScale           Precomputed RMS scale factor [1] (from the reduction kernel)
     * @param w                  Weight matrix (FP16) [outputDim × dim]
     * @param inputDim           Input dimension (dim)
     * @param outputDim          Output dimension
     * @param localWorkGroupSize Local work group size for the reduction
     */
    public static void fusedRmsNormMatmul(
            KernelContext context,
            FloatArray x,            // input (FP32)
            FloatArray output,       // output (FP32)
            FloatArray rmsWeights,   // RMS norm weights
            FloatArray rmsScale,     // rmsScale[0] = scale factor
            HalfFloatArray w,        // weight matrix (FP16)
            int inputDim,            // input dimension
            int outputDim,           // output dimension
            int localWorkGroupSize) {

        int rowId = context.groupIdx;
        int localId = context.localIdx;

        if (rowId >= outputDim) {
            return;
        }

        float scale = rmsScale.get(0);

        // Allocate shared memory for the workgroup reduction
        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);

        int rowOffset = rowId * inputDim;

        // Each thread computes a partial dot product with inline normalization
        float partialSum = 0.0f;
        for (int j = localId; j < inputDim; j += localWorkGroupSize) {
            float normalized = rmsWeights.get(j) * scale * x.get(j);
            partialSum += w.get(rowOffset + j).getFloat32() * normalized;
        }

        localSum[localId] = partialSum;
        context.localBarrier();

        // Parallel tree reduction within the workgroup
        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }

        // Thread 0 writes the final result
        if (localId == 0) {
            output.set(rowId, localSum[0]);
        }
    }
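
    /**
     * Illustrative companion kernel (a sketch, not part of the original file):
     * one way to produce the {@code rmsScale} factor consumed by
     * {@link #fusedRmsNormMatmul}. Assumes a single workgroup of
     * {@code localWorkGroupSize} threads and an epsilon of 1e-5f; the actual
     * reduction kernel in the model pipeline may differ.
     *
     * <p>Computes rmsScale[0] = 1 / sqrt(mean(x[j]^2) + eps).</p>
     */
    public static void rmsScaleReductionSketch(
            KernelContext context,
            FloatArray x,        // input hidden state (FP32) [dim]
            FloatArray rmsScale, // output scale factor [1]
            int dim,
            int localWorkGroupSize) {

        int localId = context.localIdx;
        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);

        // Each thread accumulates a partial sum of squares
        float partial = 0.0f;
        for (int j = localId; j < dim; j += localWorkGroupSize) {
            float v = x.get(j);
            partial += v * v;
        }
        localSum[localId] = partial;
        context.localBarrier();

        // Parallel tree reduction within the single workgroup
        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }

        // Thread 0 converts the sum of squares into the RMS scale factor
        if (localId == 0) {
            float meanSquare = localSum[0] / dim;
            rmsScale.set(0, 1.0f / TornadoMath.sqrt(meanSquare + 1e-5f));
        }
    }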

    /**
     * Phi3 RoPE rotation with fused KV cache copy.
     *
     * <p>Phi3 uses a different RoPE pattern than Llama/Qwen:</p>
     * <ul>
     *   <li>Pairs elements with offset headSize/2 (not adjacent pairs)</li>
     *   <li>Each thread processes one dimension pair across all heads</li>
     *   <li>Iterates over heads internally</li>
     * </ul>
     *
     * <p>This fused kernel combines:</p>
     * <ul>
     *   <li>Phi3-style RoPE rotation for Q and K</li>
     *   <li>Direct cache write for the rotated K</li>
     *   <li>Direct cache copy for V (no rotation)</li>
     * </ul>
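     *
     * <p>For example, with headSize = 8 each thread's pair is (idx, idx + 4)
     * within a head: pairs (0,4), (1,5), (2,6), (3,7) are rotated together,
     * whereas an adjacent-pair scheme would rotate (0,1), (2,3), (4,5), (6,7).</p>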
     *
     * @param context        Kernel execution context
     * @param positionHolder Current position in the sequence [1]
     * @param sq             Query vectors (in/out, rotated) [dim]
     * @param sk             Key vectors (in/out, rotated) [kvDim]
     * @param sv             Value vectors (in only) [kvDim]
     * @param keyCache       Key cache (out) [layers × contextLength × kvDim]
     * @param valueCache     Value cache (out) [layers × contextLength × kvDim]
     * @param nHeadKv        Number of KV heads
     * @param headSize       Dimension per head
     * @param kvDim          Total KV dimension (nHeadKv × headSize)
     * @param layer          Current layer index
     * @param contextLength  Maximum sequence length
     */
    public static void ropeRotationWithCacheCopyPhi3(
            KernelContext context,
            IntArray positionHolder,
            FloatArray sq,         // Q vector (in/out)
            FloatArray sk,         // K vector (in/out)
            FloatArray sv,         // V vector (in only)
            FloatArray keyCache,   // key cache (out)
            FloatArray valueCache, // value cache (out)
            int nHeadKv,
            int headSize,
            int kvDim,
            int layer,
            int contextLength) {

        int idx = context.globalIdx;
        int dimHalf = headSize / 2;

        // Each thread processes one dimension pair
        if (idx >= dimHalf) {
            return;
        }

        int pos = positionHolder.get(0);
        int cacheOffset = layer * contextLength * kvDim + pos * kvDim;

        // Calculate the rotation frequency for this dimension pair
        float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) headSize);
        float val = pos * freq;
        float fcr = TornadoMath.cos(val);
        float fci = TornadoMath.sin(val);

        // Process Q: all heads (dim = nHeads × headSize)
        int totalDimQ = sq.getSize();
        for (int base = 0; base < totalDimQ; base += headSize) {
            if (base + idx >= totalDimQ || base + idx + dimHalf >= totalDimQ) {
                break;
            }

            // Rotate Q with the headSize/2 offset pattern
            float v0 = sq.get(base + idx);
            float v1 = sq.get(base + idx + dimHalf);
            sq.set(base + idx, v0 * fcr - v1 * fci);
            sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr);
        }

        // Process K: only kvDim elements, with fused cache writes
        for (int base = 0; base < kvDim; base += headSize) {
            if (base + idx >= kvDim || base + idx + dimHalf >= kvDim) {
                break;
            }

            // Rotate K with the headSize/2 offset pattern
            float k0 = sk.get(base + idx);
            float k1 = sk.get(base + idx + dimHalf);
            float rotated0 = k0 * fcr - k1 * fci;
            float rotated1 = k0 * fci + k1 * fcr;

            // Write the rotated K back
            sk.set(base + idx, rotated0);
            sk.set(base + idx + dimHalf, rotated1);

            // Fused cache write for K
            keyCache.set(cacheOffset + base + idx, rotated0);
            keyCache.set(cacheOffset + base + idx + dimHalf, rotated1);

            // Fused cache copy for V (no rotation needed)
            valueCache.set(cacheOffset + base + idx, sv.get(base + idx));
            valueCache.set(cacheOffset + base + idx + dimHalf, sv.get(base + idx + dimHalf));
        }
    }
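
    /*
     * Launch-geometry note (an assumption about intended usage, not stated in
     * the original file): the RoPE kernel guards on idx >= headSize / 2, so a
     * WorkerGrid1D global size of headSize / 2 threads suffices; each thread
     * then iterates over all heads internally.
     */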
}