Commit af12e80

Fixes softmax LSE calculation by removing scale factor
Removes the multiplication by softmax_scale from the log-sum-exp (LSE) calculation. When the row sum is valid (nonzero and not NaN), the LSE is now row_max + log(sum) instead of row_max * softmax_scale + log(sum). The invalid-sum branch is unchanged. This drops an unnecessary scaling factor that was affecting numerical accuracy.
1 parent 9a408be commit af12e80

1 file changed: +3 -1 lines changed

csrc/src/softmax.h

Lines changed: 3 additions & 1 deletion
@@ -209,7 +209,9 @@ struct Softmax {
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
             float sum = row_sum(mi);
             float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+            lse(mi) = (sum == 0.f || sum != sum)
+                ? (Split ? -INFINITY : INFINITY)
+                : (row_max(mi) + __logf(sum));
             float scale = inv_sum;
             #pragma unroll
             for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
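
The patched expression can be checked against a host-side reference. The sketch below is illustrative only and is not part of csrc/src/softmax.h: reference_lse and naive_lse are hypothetical helper names, and it assumes the scores (and therefore the row maximum) already include softmax_scale, which the commit message implies but does not state. It mirrors only the non-Split branch for an invalid sum.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical host-side reference, not taken from the repository.
// Assumption: `scaled_scores` already include softmax_scale, so the row
// maximum is in the same units as the exponent arguments.
inline float reference_lse(const std::vector<float>& scaled_scores) {
    assert(!scaled_scores.empty());

    // Row maximum, used to keep the exponentials in a safe range.
    float row_max = -std::numeric_limits<float>::infinity();
    for (float s : scaled_scores) row_max = std::fmax(row_max, s);

    // Sum of exponentials shifted by the row maximum.
    float sum = 0.f;
    for (float s : scaled_scores) sum += std::exp(s - row_max);

    // Same validity test as the patched line (zero or NaN sum);
    // this sketch returns +infinity, as the non-Split branch does.
    if (sum == 0.f || sum != sum) return std::numeric_limits<float>::infinity();

    // Patched form: no extra multiplication by softmax_scale.
    return row_max + std::log(sum);
}

// Naive log-sum-exp for comparison; overflows for large inputs.
inline float naive_lse(const std::vector<float>& scaled_scores) {
    float sum = 0.f;
    for (float s : scaled_scores) sum += std::exp(s);
    return std::log(sum);
}
```

For small inputs such as {1, 2, 3} the two helpers should agree to within float rounding, while for inputs such as {1000, 1000} only the max-shifted form stays finite. Under the stated assumption, multiplying row_max by softmax_scale a second time would break that agreement.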
