Skip to content

Commit 7c63dc4

Browse files
committed
Refactor LogitsFP16Layer: streamline task graph setup, consolidate grid scheduler logic, and improve readability by adjusting formatting.
1 parent 1e46405 commit 7c63dc4

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,14 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
2929
super(name, state, weights, config);
3030
this.lastTaskGraphID = lastTaskGraphID;
3131
state.tempLogits.clear();
32-
3332
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
3433
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
3534
this.schedulerType = schedulerType;
3635
}
3736

38-
39-
/**
40-
* Builds the logits computation graph.
41-
*/
37+
// @formatter:off
4238
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
43-
TaskGraph logits = new TaskGraph("logits");
44-
39+
var logits = new TaskGraph("logits");
4540
// === Data Setup ===
4641
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
4742
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -97,24 +92,17 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
9792
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
9893
return logits;
9994
}
100-
95+
// @formatter:on
10196

10297
@Override
10398
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
104-
WorkerGrid logitsRMS;
105-
if (weights instanceof Qwen2TornadoWeights) {
106-
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
107-
} else {
108-
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
109-
}
110-
99+
WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
111100
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
112-
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
101+
var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
113102
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
114-
115-
tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
116103
tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
117104
tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
105+
tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
118106
return tornadoForwardScheduler;
119107
}
120108

0 commit comments

Comments
 (0)