1111import com .carrotsearch .randomizedtesting .annotations .Name ;
1212
1313import org .apache .http .util .EntityUtils ;
14- import org .elasticsearch .Version ;
1514import org .elasticsearch .client .Request ;
1615import org .elasticsearch .client .Response ;
1716import org .elasticsearch .common .Strings ;
18- import org .elasticsearch .index .mapper .SourceFieldMapper ;
17+ import org .elasticsearch .index .mapper .MapperFeatures ;
18+ import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
1919import org .elasticsearch .xcontent .XContentBuilder ;
2020import org .elasticsearch .xcontent .XContentType ;
2121
2222import java .io .IOException ;
2323import java .nio .charset .StandardCharsets ;
24- import java .util .Map ;
25- import java .util .function .Predicate ;
24+ import java .util .Set ;
25+ import java .util .stream .Collectors ;
26+ import java .util .stream .Stream ;
2627
2728import static org .elasticsearch .rest .action .search .RestSearchAction .TOTAL_HITS_AS_INT_PARAM ;
2829
@@ -36,127 +37,239 @@ public class DenseVectorMappingUpdateIT extends AbstractRollingUpgradeTestCase {
3637
/** Bulk request body for the first batch: ten index actions (_id 1 through 10), each with an 8-element "embedding" array. */
3738 private static final String BULK1 = """
3839 {"index": {"_id": "1"}}
39- {"embedding": [1, 1, 1, 1]}
40+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 1 ]}
4041 {"index": {"_id": "2"}}
41- {"embedding": [1, 1, 1, 2]}
42+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 2]}
4243 {"index": {"_id": "3"}}
43- {"embedding": [1, 1, 1, 3]}
44+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 3]}
4445 {"index": {"_id": "4"}}
45- {"embedding": [1, 1, 1, 4]}
46+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 4]}
4647 {"index": {"_id": "5"}}
47- {"embedding": [1, 1, 1, 5]}
48+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 5]}
4849 {"index": {"_id": "6"}}
49- {"embedding": [1, 1, 1, 6]}
50+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 6]}
5051 {"index": {"_id": "7"}}
51- {"embedding": [1, 1, 1, 7]}
52+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 7]}
5253 {"index": {"_id": "8"}}
53- {"embedding": [1, 1, 1, 8]}
54+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 8]}
5455 {"index": {"_id": "9"}}
55- {"embedding": [1, 1, 1, 9]}
56+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 9]}
5657 {"index": {"_id": "10"}}
57- {"embedding": [1, 1, 1, 10]}
58+ {"embedding": [1, 1, 1, 1, 1, 1, 1, 10]}
59+ """ ;
60+
/**
 * Bulk request body for the first batch when the element type is "bit": ten index actions
 * (_id 1 through 10), each with a single-element "embedding" array holding the values 1 to 10.
 */
61+ private static final String BULK1_BIT = """
62+ {"index": {"_id": "1"}}
63+ {"embedding": [1]}
64+ {"index": {"_id": "2"}}
65+ {"embedding": [2]}
66+ {"index": {"_id": "3"}}
67+ {"embedding": [3]}
68+ {"index": {"_id": "4"}}
69+ {"embedding": [4]}
70+ {"index": {"_id": "5"}}
71+ {"embedding": [5]}
72+ {"index": {"_id": "6"}}
73+ {"embedding": [6]}
74+ {"index": {"_id": "7"}}
75+ {"embedding": [7]}
76+ {"index": {"_id": "8"}}
77+ {"embedding": [8]}
78+ {"index": {"_id": "9"}}
79+ {"embedding": [9]}
80+ {"index": {"_id": "10"}}
81+ {"embedding": [10]}
5882 """ ;
5983
/** Bulk request body for the second batch: ten index actions (_id 11 through 20), each with an 8-element "embedding" array. */
6084 private static final String BULK2 = """
6185 {"index": {"_id": "11"}}
62- {"embedding": [1, 0, 1, 1]}
86+ {"embedding": [1, 1, 1, 1, 1, 0, 1, 1]}
6387 {"index": {"_id": "12"}}
64- {"embedding": [1, 2, 1, 1]}
88+ {"embedding": [1, 1, 1, 1, 1, 2, 1, 1]}
6589 {"index": {"_id": "13"}}
66- {"embedding": [1, 3, 1, 1]}
90+ {"embedding": [1, 1, 1, 1, 1, 3, 1, 1]}
6791 {"index": {"_id": "14"}}
68- {"embedding": [1, 4, 1, 1]}
92+ {"embedding": [1, 1, 1, 1, 1, 4, 1, 1]}
6993 {"index": {"_id": "15"}}
70- {"embedding": [1, 5, 1, 1]}
94+ {"embedding": [1, 1, 1, 1, 1, 5, 1, 1]}
7195 {"index": {"_id": "16"}}
72- {"embedding": [1, 6, 1, 1]}
96+ {"embedding": [1, 1, 1, 1, 1, 6, 1, 1]}
7397 {"index": {"_id": "17"}}
74- {"embedding": [1, 7, 1, 1]}
98+ {"embedding": [1, 1, 1, 1, 1, 7, 1, 1]}
7599 {"index": {"_id": "18"}}
76- {"embedding": [1, 8, 1, 1]}
100+ {"embedding": [1, 1, 1, 1, 1, 8, 1, 1]}
77101 {"index": {"_id": "19"}}
78- {"embedding": [1, 9, 1, 1]}
102+ {"embedding": [1, 1, 1, 1, 1, 9, 1, 1]}
79103 {"index": {"_id": "20"}}
80- {"embedding": [1, 10, 1, 1]}
104+ {"embedding": [1, 1, 1, 1, 1, 10, 1, 1]}
105+ """ ;
/**
 * Bulk request body for the second batch when the element type is "bit": ten index actions
 * (_id 11 through 20), each with a single-element "embedding" array holding the values 101 to 110.
 */
106+ private static final String BULK2_BIT = """
107+ {"index": {"_id": "11"}}
108+ {"embedding": [101]}
109+ {"index": {"_id": "12"}}
110+ {"embedding": [102]}
111+ {"index": {"_id": "13"}}
112+ {"embedding": [103]}
113+ {"index": {"_id": "14"}}
114+ {"embedding": [104]}
115+ {"index": {"_id": "15"}}
116+ {"embedding": [105]}
117+ {"index": {"_id": "16"}}
118+ {"embedding": [106]}
119+ {"index": {"_id": "17"}}
120+ {"embedding": [107]}
121+ {"index": {"_id": "18"}}
122+ {"embedding": [108]}
123+ {"index": {"_id": "19"}}
124+ {"embedding": [109]}
125+ {"index": {"_id": "20"}}
126+ {"embedding": [110]}
81127 """ ;
82128
/**
 * Parameterized constructor; the value bound to the "upgradedNodes" name is forwarded unchanged to the superclass.
 *
 * @param upgradedNodes presumably the number of nodes already upgraded at this test stage - TODO confirm against the base class
 */
83129 public DenseVectorMappingUpdateIT (@ Name ("upgradedNodes" ) int upgradedNodes ) {
84130 super (upgradedNodes );
85131 }
86132
87133 public void testDenseVectorMappingUpdateOnOldCluster () throws IOException {
88- if (oldClusterHasFeature ("gte_v8.7.1" )) {
89- String indexName = "test_index" ;
90- if (isOldCluster ()) {
91- Request createIndex = new Request ("PUT" , "/" + indexName );
92- boolean useSyntheticSource = randomBoolean () && oldClusterHasFeature (SYNTHETIC_SOURCE_FEATURE );
93-
94- boolean useIndexSetting = SourceFieldMapper .onOrAfterDeprecateModeVersion (getOldClusterIndexVersion ());
95- XContentBuilder payload = XContentBuilder .builder (XContentType .JSON .xContent ()).startObject ();
96- if (useSyntheticSource ) {
97- if (useIndexSetting ) {
134+ String indexName = "test_index_type_change" ;
135+ if (isOldCluster ()) {
136+ Request createIndex = new Request ("PUT" , "/" + indexName );
137+ boolean useSyntheticSource = randomBoolean ();
138+
139+ XContentBuilder payload = XContentBuilder .builder (XContentType .JSON .xContent ()).startObject ();
140+ if (useSyntheticSource ) {
141+ payload .startObject ("settings" ).field ("index.mapping.source.mode" , "synthetic" ).endObject ();
142+ }
143+ payload .startObject ("mappings" );
144+ payload .startObject ("properties" )
145+ .startObject ("embedding" )
146+ .field ("type" , "dense_vector" )
147+ .field ("index" , "true" )
148+ .field ("dims" , 8 )
149+ .field ("similarity" , "cosine" )
150+ .startObject ("index_options" )
151+ .field ("type" , "hnsw" )
152+ .field ("m" , "16" )
153+ .field ("ef_construction" , "100" )
154+ .endObject ()
155+ .endObject ()
156+ .endObject ()
157+ .endObject ()
158+ .endObject ();
159+ createIndex .setJsonEntity (Strings .toString (payload ));
160+ client ().performRequest (createIndex );
161+ Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
162+ index .addParameter ("refresh" , "true" );
163+ index .setJsonEntity (BULK1 );
164+ client ().performRequest (index );
165+ }
166+
167+ int expectedCount = 10 ;
168+
169+ assertCount (indexName , expectedCount );
170+
171+ if (isUpgradedCluster ()) {
172+ Request updateMapping = new Request ("PUT" , "/" + indexName + "/_mapping" );
173+ XContentBuilder mappings = XContentBuilder .builder (XContentType .JSON .xContent ())
174+ .startObject ()
175+ .startObject ("properties" )
176+ .startObject ("embedding" )
177+ .field ("type" , "dense_vector" )
178+ .field ("index" , "true" )
179+ .field ("dims" , 8 )
180+ .field ("similarity" , "cosine" )
181+ .startObject ("index_options" )
182+ .field ("type" , "int8_hnsw" )
183+ .field ("m" , "16" )
184+ .field ("ef_construction" , "100" )
185+ .endObject ()
186+ .endObject ()
187+ .endObject ()
188+ .endObject ();
189+ updateMapping .setJsonEntity (Strings .toString (mappings ));
190+ assertOK (client ().performRequest (updateMapping ));
191+ Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
192+ index .addParameter ("refresh" , "true" );
193+ index .setJsonEntity (BULK2 );
194+ assertOK (client ().performRequest (index ));
195+ expectedCount = 20 ;
196+ assertCount (indexName , expectedCount );
197+ }
198+ }
199+
/** Pairs an index_options type (null when no index_options object should be sent) with the element types to exercise for it. */
200+ private record Index (String type , Set <String > elementTypes ) {}
201+
/** The {@code toString()} form of every {@code DenseVectorFieldMapper.ElementType} constant, as an unmodifiable set. */
202+ private static final Set <String > ALL_ELEMENT_TYPES = Stream .of (DenseVectorFieldMapper .ElementType .values ())
203+ .map (Object ::toString )
204+ .collect (Collectors .toUnmodifiableSet ());
/**
 * The index_options types under test, each with the element types it is combined with.
 * The null, "hnsw" and "flat" entries use every element type; the int8 and int4 variants list only
 * "float" and "bfloat16".
 */
205+ private static final Set <Index > INDEXES = Set .of (
206+ new Index (null , ALL_ELEMENT_TYPES ),
207+ new Index ("hnsw" , ALL_ELEMENT_TYPES ),
208+ new Index ("int8_hnsw" , Set .of ("float" , "bfloat16" )),
209+ new Index ("int4_hnsw" , Set .of ("float" , "bfloat16" )),
210+ new Index ("flat" , ALL_ELEMENT_TYPES ),
211+ new Index ("int8_flat" , Set .of ("float" , "bfloat16" )),
212+ new Index ("int4_flat" , Set .of ("float" , "bfloat16" ))
// NOTE(review): the bbq_* entries below are disabled; the reason is not visible here - confirm before re-enabling or deleting.
213+ // new Index("bbq_hnsw", Set.of("float", "bfloat16")),
214+ // new Index("bbq_flat", Set.of("float", "bfloat16"))
215+ );
216+
217+ public void testDenseVectorIndexOverUpgrade () throws IOException {
218+ if (isOldCluster ()) {
219+ boolean useSyntheticSource = randomBoolean ();
220+
221+ for (Index i : INDEXES ) {
222+ for (String elementType : i .elementTypes ()) {
223+ if (clusterSupportsIndex (i .type (), elementType ) == false ) {
224+ continue ;
225+ }
226+
227+ String indexName = "test_index_" + i .type () + "_" + elementType ;
228+ Request createIndex = new Request ("PUT" , "/" + indexName );
229+
230+ XContentBuilder payload = XContentBuilder .builder (XContentType .JSON .xContent ()).startObject ();
231+ if (useSyntheticSource ) {
98232 payload .startObject ("settings" ).field ("index.mapping.source.mode" , "synthetic" ).endObject ();
99233 }
234+ payload .startObject ("mappings" );
235+ payload .startObject ("properties" )
236+ .startObject ("embedding" )
237+ .field ("type" , "dense_vector" )
238+ .field ("element_type" , elementType )
239+ .field ("index" , "true" )
240+ .field ("dims" , 8 )
241+ .field ("similarity" , "l2_norm" );
242+ if (i .type () != null ) {
243+ payload .startObject ("index_options" ).field ("type" , i .type ()).endObject ();
244+ }
245+ payload .endObject ().endObject ().endObject ().endObject ();
246+ createIndex .setJsonEntity (Strings .toString (payload ));
247+ client ().performRequest (createIndex );
248+ Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
249+ index .addParameter ("refresh" , "true" );
250+ index .setJsonEntity (elementType .equals ("bit" ) ? BULK1_BIT : BULK1 );
251+ client ().performRequest (index );
252+
253+ assertCount (indexName , 10 );
100254 }
101- payload .startObject ("mappings" );
102- if (useIndexSetting == false ) {
103- payload .startObject ("_source" );
104- payload .field ("mode" , "synthetic" );
105- payload .endObject ();
106- }
107- payload .startObject ("properties" )
108- .startObject ("embedding" )
109- .field ("type" , "dense_vector" )
110- .field ("index" , "true" )
111- .field ("dims" , 4 )
112- .field ("similarity" , "cosine" )
113- .startObject ("index_options" )
114- .field ("type" , "hnsw" )
115- .field ("m" , "16" )
116- .field ("ef_construction" , "100" )
117- .endObject ()
118- .endObject ()
119- .endObject ()
120- .endObject ()
121- .endObject ();
122- createIndex .setJsonEntity (Strings .toString (payload ));
123- client ().performRequest (createIndex );
124- Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
125- index .addParameter ("refresh" , "true" );
126- index .setJsonEntity (BULK1 );
127- client ().performRequest (index );
128255 }
256+ }
129257
130- int expectedCount = 10 ;
258+ if (isUpgradedCluster ()) {
259+ for (Index i : INDEXES ) {
260+ for (String elementType : i .elementTypes ()) {
261+ if (clusterSupportsIndex (i .type (), elementType ) == false ) {
262+ continue ;
263+ }
264+ String indexName = "test_index_" + i .type () + "_" + elementType ;
131265
132- assertCount (indexName , expectedCount );
266+ Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
267+ index .addParameter ("refresh" , "true" );
268+ index .setJsonEntity (elementType .equals ("bit" ) ? BULK2_BIT : BULK2 );
269+ assertOK (client ().performRequest (index ));
133270
134- if (isUpgradedCluster () && clusterSupportsDenseVectorTypeUpdate ()) {
135- Request updateMapping = new Request ("PUT" , "/" + indexName + "/_mapping" );
136- XContentBuilder mappings = XContentBuilder .builder (XContentType .JSON .xContent ())
137- .startObject ()
138- .startObject ("properties" )
139- .startObject ("embedding" )
140- .field ("type" , "dense_vector" )
141- .field ("index" , "true" )
142- .field ("dims" , 4 )
143- .field ("similarity" , "cosine" )
144- .startObject ("index_options" )
145- .field ("type" , "int8_hnsw" )
146- .field ("m" , "16" )
147- .field ("ef_construction" , "100" )
148- .endObject ()
149- .endObject ()
150- .endObject ()
151- .endObject ();
152- updateMapping .setJsonEntity (Strings .toString (mappings ));
153- assertOK (client ().performRequest (updateMapping ));
154- Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
155- index .addParameter ("refresh" , "true" );
156- index .setJsonEntity (BULK2 );
157- assertOK (client ().performRequest (index ));
158- expectedCount = 20 ;
159- assertCount (indexName , expectedCount );
271+ assertCount (indexName , 20 );
272+ }
160273 }
161274 }
162275 }
@@ -172,13 +285,10 @@ private void assertCount(String index, int count) throws IOException {
172285 );
173286 }
174287
175- private boolean clusterSupportsDenseVectorTypeUpdate () throws IOException {
176- Map <?, ?> response = entityAsMap (client ().performRequest (new Request ("GET" , "_nodes" )));
177- Map <?, ?> nodes = (Map <?, ?>) response .get ("nodes" );
178-
179- Predicate <Map <?, ?>> nodeSupportsBulkApi = n -> Version .fromString (n .get ("version" ).toString ()).onOrAfter (Version .V_8_15_0 );
180-
181- return nodes .values ().stream ().map (o -> (Map <?, ?>) o ).allMatch (nodeSupportsBulkApi );
288+ private static boolean clusterSupportsIndex (String type , String elementType ) {
289+ if (elementType .equals ("bfloat16" ) && oldClusterHasFeature (MapperFeatures .GENERIC_VECTOR_FORMAT ) == false ) {
290+ return false ;
291+ }
292+ return true ;
182293 }
183-
184294}
0 commit comments