|
3 | 3 | import uk.ac.manchester.tornado.api.KernelContext; |
4 | 4 | import uk.ac.manchester.tornado.api.annotations.Parallel; |
5 | 5 | import uk.ac.manchester.tornado.api.math.TornadoMath; |
6 | | -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
7 | | -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; |
8 | | -import uk.ac.manchester.tornado.api.types.arrays.Int8Array; |
9 | | -import uk.ac.manchester.tornado.api.types.arrays.IntArray; |
| 6 | +import uk.ac.manchester.tornado.api.types.HalfFloat; |
| 7 | +import uk.ac.manchester.tornado.api.types.arrays.*; |
10 | 8 |
|
11 | 9 | public class TransformerComputeKernelsLayered { |
12 | 10 |
|
@@ -1015,6 +1013,101 @@ public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int |
1015 | 1013 | return localSums[0]; |
1016 | 1014 | } |
1017 | 1015 |
|
| 1016 | + public static void matrixVectorGenericQ8Byte(KernelContext context, FloatArray x, FloatArray output, ByteArray q, int dim1, int dim0, int localWorkGroupSize) { |
| 1017 | + int rowId = context.groupIdx; |
| 1018 | + int localId = context.localIdx; |
| 1019 | + |
| 1020 | + if (rowId >= dim0) { |
| 1021 | + return; |
| 1022 | + } |
| 1023 | + |
| 1024 | + float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1); |
| 1025 | + |
| 1026 | + // Thread 0 writes the result |
| 1027 | + if (localId == 0) { |
| 1028 | + output.set(rowId, sum); |
| 1029 | + } |
| 1030 | + } |
| 1031 | + |
    /**
     * Cooperative dot product between dense vector {@code x} and one Q8_0-quantized
     * matrix row stored in raw bytes. Each 34-byte Q8_0 block holds a 2-byte
     * half-precision scale followed by 32 signed int8 quants (layout established by
     * the getHalf/get accesses below). All {@code localSize} threads of the
     * workgroup share the work and combine partial sums via a local-memory tree
     * reduction, so every thread returns the same value.
     *
     * NOTE(review): the reduction loop assumes {@code localSize} is a power of
     * two — confirm against the launch configuration.
     *
     * @param context   kernel execution context (workgroup/thread ids, barriers)
     * @param localSize number of threads in the workgroup
     * @param x         dense input vector of length {@code n}
     * @param q         Q8_0-packed weights; this row starts at block index groupIdx * ceil(n/32)
     * @param n         number of elements in the row / in {@code x}
     * @return the full row dot product (valid in every thread after the reduction)
     */
    public static float matrixVectorRowMajorOptimizedQ8_0Byte(KernelContext context, int localSize, FloatArray x, ByteArray q, int n) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        int blockSize = 32;
        final int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants

        // Allocate local memory for reduction
        float[] localSums = context.allocateFloatLocalArray(localSize);

        int blocksPerRow = (n + blockSize - 1) / blockSize;
        int rowBlockOffset = rowId * blocksPerRow; // Starting block index for this row

        // 4-way unrolling: four independent accumulators break the add
        // dependency chain across iterations.
        float partialSum1 = 0.0f;
        float partialSum2 = 0.0f;
        float partialSum3 = 0.0f;
        float partialSum4 = 0.0f;

        // Main loop - process 4 elements at a time. j is always a multiple of 4,
        // and blockSize is a multiple of 4, so the 4 quants below never straddle
        // a Q8_0 block boundary and all share a single scale.
        for (int j = localId * 4; j < n - 3; j += localSize * 4) {
            int blockIdx = j / blockSize;
            int withinBlockIdx = j % blockSize;

            // Calculate byte offset for this Q8_0 block
            int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;

            // Load scale (first 2 bytes of block as HalfFloat)
            HalfFloat scale = q.getHalf(blockByteOffset);
            float scaleFloat = scale.getFloat32();

            // Load 4 consecutive quantized values
            int quantsOffset = blockByteOffset + 2 + withinBlockIdx; // Skip 2-byte scale
            byte quant1 = q.get(quantsOffset);
            byte quant2 = q.get(quantsOffset + 1);
            byte quant3 = q.get(quantsOffset + 2);
            byte quant4 = q.get(quantsOffset + 3);

            // Dequantize and multiply
            partialSum1 += ((float) quant1 * scaleFloat) * x.get(j);
            partialSum2 += ((float) quant2 * scaleFloat) * x.get(j + 1);
            partialSum3 += ((float) quant3 * scaleFloat) * x.get(j + 2);
            partialSum4 += ((float) quant4 * scaleFloat) * x.get(j + 3);
        }

        float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;

        // Handle remaining elements: the main loop covers indices below
        // (n / 4) * 4, so the tail strides over the last n % 4 indices.
        for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) {
            int blockIdx = j / blockSize;
            int withinBlockIdx = j % blockSize;

            // Calculate byte offset for this Q8_0 block
            int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES;

            // Load scale
            HalfFloat scale = q.getHalf(blockByteOffset);
            float scaleFloat = scale.getFloat32();

            // Load quantized value
            byte quant = q.get(blockByteOffset + 2 + withinBlockIdx);

            partialSum += ((float) quant * scaleFloat) * x.get(j);
        }

        localSums[localId] = partialSum;
        context.localBarrier();

        // Parallel reduction: halve the active stride each step; barrier keeps
        // reads of localSums[localId + stride] ordered after the prior writes.
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSums[localId] += localSums[localId + stride];
            }
            context.localBarrier();
        }

        // After the reduction, localSums[0] holds the full row dot product.
        return localSums[0];

    }
| 1110 | + |
1018 | 1111 | public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { |
1019 | 1112 | // One row per workgroup (not per thread) |
1020 | 1113 | int rowId = context.groupIdx; |
|
0 commit comments