Skip to content

Commit 04db93d

Browse files
Add matrixVectorGenericQ8Byte and matrixVectorRowMajorOptimizedQ8_0Byte kernels for Q8_0 matrix-vector computations
1 parent 7456d59 commit 04db93d

File tree

1 file changed

+97
-4
lines changed

1 file changed

+97
-4
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 97 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,8 @@
33
import uk.ac.manchester.tornado.api.KernelContext;
44
import uk.ac.manchester.tornado.api.annotations.Parallel;
55
import uk.ac.manchester.tornado.api.math.TornadoMath;
6-
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
7-
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
8-
import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
9-
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
6+
import uk.ac.manchester.tornado.api.types.HalfFloat;
7+
import uk.ac.manchester.tornado.api.types.arrays.*;
108

119
public class TransformerComputeKernelsLayered {
1210

@@ -1015,6 +1013,101 @@ public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int
10151013
return localSums[0];
10161014
}
10171015

1016+
public static void matrixVectorGenericQ8Byte(KernelContext context, FloatArray x, FloatArray output, ByteArray q, int dim1, int dim0, int localWorkGroupSize) {
1017+
int rowId = context.groupIdx;
1018+
int localId = context.localIdx;
1019+
1020+
if (rowId >= dim0) {
1021+
return;
1022+
}
1023+
1024+
float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1);
1025+
1026+
// Thread 0 writes the result
1027+
if (localId == 0) {
1028+
output.set(rowId, sum);
1029+
}
1030+
}
1031+
1032+
public static float matrixVectorRowMajorOptimizedQ8_0Byte(KernelContext context, int localSize, FloatArray x, ByteArray q, int n) {
1033+
int rowId = context.groupIdx;
1034+
int localId = context.localIdx;
1035+
int blockSize = 32;
1036+
final int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
1037+
1038+
// Allocate local memory for reduction
1039+
float[] localSums = context.allocateFloatLocalArray(localSize);
1040+
1041+
int blocksPerRow = (n + blockSize - 1) / blockSize;
1042+
int rowBlockOffset = rowId * blocksPerRow; // Starting block index for this row
1043+
1044+
// 4-way unrolling
1045+
float partialSum1 = 0.0f;
1046+
float partialSum2 = 0.0f;
1047+
float partialSum3 = 0.0f;
1048+
float partialSum4 = 0.0f;
1049+
1050+
// Main loop - process 4 elements at a time
1051+
for (int j = localId * 4; j < n - 3; j += localSize * 4) {
1052+
int blockIdx = j / blockSize;
1053+
int withinBlockIdx = j % blockSize;
1054+
1055+
// Calculate byte offset for this Q8_0 block
1056+
int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
1057+
1058+
// Load scale (first 2 bytes of block as HalfFloat)
1059+
HalfFloat scale = q.getHalf(blockByteOffset);
1060+
float scaleFloat = scale.getFloat32();
1061+
1062+
// Load 4 consecutive quantized values
1063+
int quantsOffset = blockByteOffset + 2 + withinBlockIdx; // Skip 2-byte scale
1064+
byte quant1 = q.get(quantsOffset);
1065+
byte quant2 = q.get(quantsOffset + 1);
1066+
byte quant3 = q.get(quantsOffset + 2);
1067+
byte quant4 = q.get(quantsOffset + 3);
1068+
1069+
// Dequantize and multiply
1070+
partialSum1 += ((float) quant1 * scaleFloat) * x.get(j);
1071+
partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1);
1072+
partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2);
1073+
partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3);
1074+
}
1075+
1076+
float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;
1077+
1078+
// Handle remaining elements
1079+
for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) {
1080+
int blockIdx = j / blockSize;
1081+
int withinBlockIdx = j % blockSize;
1082+
1083+
// Calculate byte offset for this Q8_0 block
1084+
int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;
1085+
1086+
// Load scale
1087+
HalfFloat scale = q.getHalf(blockByteOffset);
1088+
float scaleFloat = scale.getFloat32();
1089+
1090+
// Load quantized value
1091+
byte quant = q.get(blockByteOffset + 2 + withinBlockIdx);
1092+
1093+
partialSum += ((float) quant * scaleFloat) * x.get(j);
1094+
}
1095+
1096+
localSums[localId] = partialSum;
1097+
context.localBarrier();
1098+
1099+
// Parallel reduction
1100+
for (int stride = localSize / 2; stride > 0; stride >>= 1) {
1101+
if (localId < stride) {
1102+
localSums[localId] += localSums[localId + stride];
1103+
}
1104+
context.localBarrier();
1105+
}
1106+
1107+
return localSums[0];
1108+
1109+
}
1110+
10181111
public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) {
10191112
// One row per workgroup (not per thread)
10201113
int rowId = context.groupIdx;

0 commit comments

Comments
 (0)