@@ -29,19 +29,14 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
2929 super (name , state , weights , config );
3030 this .lastTaskGraphID = lastTaskGraphID ;
3131 state .tempLogits .clear ();
32-
3332 var tornadoWeights = requireWeightsType (weights , TornadoWeights .class , "LogitsFP16Layer" , "TornadoTensor" );
3433 this .logitsTaskGraph = setupLogitsTaskGraph (tornadoWeights , config );
3534 this .schedulerType = schedulerType ;
3635 }
3736
38-
39- /**
40- * Builds the logits computation graph.
41- */
37+ // @formatter:off
4238 private TaskGraph setupLogitsTaskGraph (TornadoWeights weights , Configuration config ) {
43- TaskGraph logits = new TaskGraph ("logits" );
44-
39+ var logits = new TaskGraph ("logits" );
4540 // === Data Setup ===
4641 logits .consumeFromDevice (lastTaskGraphID , state .wrapX );
4742 logits .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
@@ -97,24 +92,17 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
9792 logits .transferToHost (DataTransferMode .EVERY_EXECUTION , state .wrapLogits );
9893 return logits ;
9994 }
100-
95+ // @formatter:on
10196
/**
 * Registers the worker grids for the logits tasks on the supplied scheduler and
 * returns that same scheduler instance.
 *
 * Three grid entries are added: "logits.rms_reduce" and "logits.rms_apply_fp16"
 * share one RMS-norm worker, and "logits.vocab_proj" gets its own 1D worker.
 *
 * NOTE(review): in this diff view, lines marked '-' are the removed version and
 * lines marked '+' are the replacement; unmarked lines are unchanged context.
 *
 * @param tornadoForwardScheduler scheduler that receives the worker grids
 * @return the scheduler passed in, after the three addWorkerGrid calls
 */
10297 @ Override
10398 public GridScheduler updateGridScheduler (GridScheduler tornadoForwardScheduler ) {
// Removed form: if/else choosing the second createRmsNormWorker argument by weights type.
104- WorkerGrid logitsRMS ;
105- if (weights instanceof Qwen2TornadoWeights ) {
106- logitsRMS = WorkerGridFactory .createRmsNormWorker (config .dim (), 32 );
107- } else {
108- logitsRMS = WorkerGridFactory .createRmsNormWorker (config .dim (), 256 );
109- }
110-
// Added form: same selection as a ternary -- 32 for Qwen2TornadoWeights, otherwise 256.
// Presumably the second argument is the local work-group size; confirm against WorkerGridFactory.
99+ WorkerGrid logitsRMS = WorkerGridFactory .createRmsNormWorker (config .dim (), weights instanceof Qwen2TornadoWeights ? 32 : 256 );
// Size passed to the 1D worker: vocabulary size times the two scaling constants.
111100 var vocabSizeRowMajor = config .vocabularySize () * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS ;
112- WorkerGrid vocabWorker = new WorkerGrid1D (vocabSizeRowMajor );
101+ var vocabWorker = new WorkerGrid1D (vocabSizeRowMajor );
// Local work is set from the same two constants that scale the size above.
113102 vocabWorker .setLocalWork (LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS , 1 , 1 );
114-
// The "logits.vocab_proj" registration is moved below the two RMS registrations;
// each call uses a distinct key.
115- tornadoForwardScheduler .addWorkerGrid ("logits.vocab_proj" , vocabWorker );
116103 tornadoForwardScheduler .addWorkerGrid ("logits.rms_reduce" , logitsRMS );
117104 tornadoForwardScheduler .addWorkerGrid ("logits.rms_apply_fp16" , logitsRMS );
105+ tornadoForwardScheduler .addWorkerGrid ("logits.vocab_proj" , vocabWorker );
118106 return tornadoForwardScheduler ;
119107 }
120108
0 commit comments