Skip to content

Commit 02b1a2c

Browse files
committed
Add splitQKV and splitGateUpSiLU worker grids to Phi3 FP16 FFN layers and update grid scheduler configuration
1 parent e7d79c9 commit 02b1a2c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,16 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
8989
int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
9090
WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
9191

92+
WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128);
93+
94+
// SplitGateUpAndSiLU worker
95+
WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128);
96+
97+
9298
// Map workers to tasks for each layer
9399
for (int i = 0; i < config.numberOfLayers(); i++) {
100+
gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
101+
gridScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
94102
gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
95103
gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
96104
gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker);

0 commit comments

Comments
 (0)