Skip to content

Commit 53dd5e2

Browse files
authored
Use round-to-even for bfloat16 conversion (#138754)
1 parent 0a8c107 commit 53dd5e2

File tree

2 files changed

+108
-5
lines changed

2 files changed

+108
-5
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,25 @@ public final class BFloat16 {
1818
/** Number of bytes occupied by one bfloat16 value; defined as {@link Short#BYTES}. */
public static final int BYTES = Short.BYTES;
1919

2020
public static short floatToBFloat16(float f) {
21-
// this rounds towards 0
21+
// this rounds towards even
2222
// zero - zero exp, zero fraction
2323
// denormal - zero exp, non-zero fraction
2424
// infinity - all-1 exp, zero fraction
2525
// NaN - all-1 exp, non-zero fraction
26-
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
27-
// infinities
28-
return (short) (Float.floatToIntBits(f) >>> 16);
26+
27+
// note that floatToIntBits doesn't maintain specific NaN values,
28+
// unlike floatToRawIntBits, but instead can return different NaN bit patterns.
29+
// this means that a NaN is unlikely to be turned into infinity by rounding
30+
31+
int bits = Float.floatToIntBits(f);
32+
// with thanks to https://github.com/microsoft/onnxruntime Fp16Conversions
33+
int roundingBias = 0x7fff + ((bits >> 16) & 1);
34+
bits += roundingBias;
35+
return (short) (bits >> 16);
2936
}
3037

3138
public static float truncateToBFloat16(float f) {
32-
return Float.intBitsToFloat(Float.floatToIntBits(f) & 0xffff0000);
39+
return Float.intBitsToFloat(floatToBFloat16(f) << 16);
3340
}
3441

3542
public static float bFloat16ToFloat(short bf) {
BFloat16Tests.java (new file, package org.elasticsearch.index.codec.vectors)

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.elasticsearch.test.ESTestCase;
13+
14+
import static org.hamcrest.Matchers.equalTo;
15+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
16+
import static org.hamcrest.Matchers.lessThan;
17+
18+
public class BFloat16Tests extends ESTestCase {
19+
20+
public void testRoundToEven() {
21+
int exp = 0b001111110; // to create floating numbers around 1.0
22+
23+
// exact bfloat16 value
24+
float bfloat16 = construct(exp, 0b1111001_00000000_00000000);
25+
assertRounding(bfloat16, bfloat16);
26+
27+
// some FP examples
28+
assertRounding(1.003f, 1.0f);
29+
assertRounding(1.004f, 1.0078125f);
30+
31+
// round down
32+
assertRounding(construct(exp, 0b0000001_01111111_11111111), construct(exp, 0b0000001_00000000_00000000));
33+
34+
// round up
35+
assertRounding(construct(exp, 0b0000001_10000000_00000001), construct(exp, 0b0000010_00000000_00000000));
36+
37+
// split down to even
38+
assertRounding(construct(exp, 0b000010_10000000_00000000), construct(exp, 0b000010_00000000_00000000));
39+
40+
// split up to even
41+
assertRounding(construct(exp, 0b000001_10000000_00000000), construct(exp, 0b000010_00000000_00000000));
42+
43+
// round up, overflowing into exponent
44+
assertRounding(construct(0b000111111, 0b1111111_10000000_00000000), construct(0b001000000, 0b0000000_00000000_00000000));
45+
46+
// round up, overflowing from denormal to normal number
47+
assertRounding(construct(0b000000000, 0b1111111_10000000_00000000), construct(0b000000001, 0b0000000_00000000_00000000));
48+
49+
// round to positive infinity
50+
assertThat(BFloat16.truncateToBFloat16(construct(0b011111110, 0b1111111_10000000_00000000)), equalTo(Float.POSITIVE_INFINITY));
51+
52+
// round to negative infinity
53+
assertThat(BFloat16.truncateToBFloat16(construct(0b111111110, 0b1111111_10000000_00000000)), equalTo(Float.NEGATIVE_INFINITY));
54+
55+
// round to zero
56+
assertRounding(construct(0b000000000, 0b0000000_10000000_00000000), 0f);
57+
58+
// rounding the standard NaN value should be unchanged
59+
assertThat(Float.floatToRawIntBits(BFloat16.truncateToBFloat16(Float.NaN)), equalTo(Float.floatToRawIntBits(Float.NaN)));
60+
61+
// you would expect this to be turned into infinity due to overflow, but instead
62+
// it stays a NaN with a different bit pattern due to using floatToIntBits rather than floatToRawIntBits
63+
// inside floatToBFloat16
64+
assertTrue(Float.isNaN(BFloat16.truncateToBFloat16(construct(0b011111111, 0b0000000_10000000_00000000))));
65+
}
66+
67+
private static float construct(int exp, int mantissa) {
68+
assert (exp & 0xfffffe00) == 0;
69+
assert (mantissa & 0xf8000000) == 0;
70+
return Float.intBitsToFloat((exp << 23) | mantissa);
71+
}
72+
73+
private static void assertRounding(float value, float expectedRounded) {
74+
assert (Float.floatToIntBits(expectedRounded) & 0xffff) == 0;
75+
76+
// rounded float value to check should be close to input value
77+
// this checks the bit representations in the tests are actually sensible
78+
assertThat(Math.abs(value - expectedRounded), lessThan(0.004f));
79+
80+
float rounded = BFloat16.truncateToBFloat16(value);
81+
82+
assertEquals(
83+
value + " rounded to " + rounded + ", not " + expectedRounded,
84+
Float.floatToIntBits(expectedRounded),
85+
Float.floatToIntBits(rounded)
86+
);
87+
88+
// there should not be a closer bfloat16 value (comparing using FP math) than the expected rounded value
89+
float delta = Math.abs(value - rounded);
90+
float higherValue = Float.intBitsToFloat(Float.floatToIntBits(rounded) + 0x10000);
91+
assertThat(Math.abs(value - higherValue), greaterThanOrEqualTo(delta));
92+
93+
float lowerValue = Float.intBitsToFloat(Float.floatToIntBits(rounded) - 0x10000);
94+
assertThat(Math.abs(value - lowerValue), greaterThanOrEqualTo(delta));
95+
}
96+
}

0 commit comments

Comments
 (0)