2121
2222import java .io .IOException ;
2323import java .nio .charset .StandardCharsets ;
24+ import java .util .Arrays ;
25+ import java .util .OptionalInt ;
2426import java .util .Set ;
2527import java .util .stream .Collectors ;
28+ import java .util .stream .IntStream ;
2629import java .util .stream .Stream ;
2730
2831import static org .elasticsearch .rest .action .search .RestSearchAction .TOTAL_HITS_AS_INT_PARAM ;
3336 */
3437public class DenseVectorMappingUpdateIT extends AbstractRollingUpgradeTestCase {
3538
36- private static final String SYNTHETIC_SOURCE_FEATURE = "gte_v8.12.0" ;
39+ private static String generateBulkData (int upgradedNodes , int dimensions ) {
40+ StringBuilder sb = new StringBuilder ();
3741
38- private static final String BULK1 = """
39- {"index": {"_id": "1"}}
40- {"embedding": [1, 1, 1, 1, 1, 1, 1, 1]}
41- {"index": {"_id": "2"}}
42- {"embedding": [1, 1, 1, 1, 1, 1, 1, 2]}
43- {"index": {"_id": "3"}}
44- {"embedding": [1, 1, 1, 1, 1, 1, 1, 3]}
45- {"index": {"_id": "4"}}
46- {"embedding": [1, 1, 1, 1, 1, 1, 1, 4]}
47- {"index": {"_id": "5"}}
48- {"embedding": [1, 1, 1, 1, 1, 1, 1, 5]}
49- {"index": {"_id": "6"}}
50- {"embedding": [1, 1, 1, 1, 1, 1, 1, 6]}
51- {"index": {"_id": "7"}}
52- {"embedding": [1, 1, 1, 1, 1, 1, 1, 7]}
53- {"index": {"_id": "8"}}
54- {"embedding": [1, 1, 1, 1, 1, 1, 1, 8]}
55- {"index": {"_id": "9"}}
56- {"embedding": [1, 1, 1, 1, 1, 1, 1, 9]}
57- {"index": {"_id": "10"}}
58- {"embedding": [1, 1, 1, 1, 1, 1, 1, 10]}
59- """ ;
42+ int [] vals = new int [dimensions ];
43+ Arrays .fill (vals , 1 );
6044
61- private static final String BULK1_BIT = """
62- {"index": {"_id": "1"}}
63- {"embedding": [1]}
64- {"index": {"_id": "2"}}
65- {"embedding": [2]}
66- {"index": {"_id": "3"}}
67- {"embedding": [3]}
68- {"index": {"_id": "4"}}
69- {"embedding": [4]}
70- {"index": {"_id": "5"}}
71- {"embedding": [5]}
72- {"index": {"_id": "6"}}
73- {"embedding": [6]}
74- {"index": {"_id": "7"}}
75- {"embedding": [7]}
76- {"index": {"_id": "8"}}
77- {"embedding": [8]}
78- {"index": {"_id": "9"}}
79- {"embedding": [9]}
80- {"index": {"_id": "10"}}
81- {"embedding": [10]}
82- """ ;
45+ // 1-10, 11-20, 21-30...
46+ IntStream docs = IntStream .rangeClosed (1 + (upgradedNodes * 10 ), (upgradedNodes + 1 ) * 10 );
8347
84- private static final String BULK2 = """
85- {"index": {"_id": "11"}}
86- {"embedding": [1, 1, 1, 1, 1, 0, 1, 1]}
87- {"index": {"_id": "12"}}
88- {"embedding": [1, 1, 1, 1, 1, 2, 1, 1]}
89- {"index": {"_id": "13"}}
90- {"embedding": [1, 1, 1, 1, 1, 3, 1, 1]}
91- {"index": {"_id": "14"}}
92- {"embedding": [1, 1, 1, 1, 1, 4, 1, 1]}
93- {"index": {"_id": "15"}}
94- {"embedding": [1, 1, 1, 1, 1, 5, 1, 1]}
95- {"index": {"_id": "16"}}
96- {"embedding": [1, 1, 1, 1, 1, 6, 1, 1]}
97- {"index": {"_id": "17"}}
98- {"embedding": [1, 1, 1, 1, 1, 7, 1, 1]}
99- {"index": {"_id": "18"}}
100- {"embedding": [1, 1, 1, 1, 1, 8, 1, 1]}
101- {"index": {"_id": "19"}}
102- {"embedding": [1, 1, 1, 1, 1, 9, 1, 1]}
103- {"index": {"_id": "20"}}
104- {"embedding": [1, 1, 1, 1, 1, 10, 1, 1]}
105- """ ;
106- private static final String BULK2_BIT = """
107- {"index": {"_id": "11"}}
108- {"embedding": [101]}
109- {"index": {"_id": "12"}}
110- {"embedding": [102]}
111- {"index": {"_id": "13"}}
112- {"embedding": [103]}
113- {"index": {"_id": "14"}}
114- {"embedding": [104]}
115- {"index": {"_id": "15"}}
116- {"embedding": [105]}
117- {"index": {"_id": "16"}}
118- {"embedding": [106]}
119- {"index": {"_id": "17"}}
120- {"embedding": [107]}
121- {"index": {"_id": "18"}}
122- {"embedding": [108]}
123- {"index": {"_id": "19"}}
124- {"embedding": [109]}
125- {"index": {"_id": "20"}}
126- {"embedding": [110]}
127- """ ;
48+ for (var it = docs .iterator (); it .hasNext ();) {
49+ vals [upgradedNodes ]++;
50+
51+ sb .append ("{\" index\" : {\" _id\" : \" " ).append (it .nextInt ()).append ("\" }}" );
52+ sb .append (System .lineSeparator ());
53+ sb .append ("{\" embedding\" : " ).append (Arrays .toString (vals )).append ("}" );
54+ sb .append (System .lineSeparator ());
55+ }
56+
57+ return sb .toString ();
58+ }
59+
60+ private final int upgradedNodes ;
12861
12962 public DenseVectorMappingUpdateIT (@ Name ("upgradedNodes" ) int upgradedNodes ) {
13063 super (upgradedNodes );
64+ this .upgradedNodes = upgradedNodes ;
13165 }
13266
13367 public void testDenseVectorMappingUpdateOnOldCluster () throws IOException {
@@ -160,7 +94,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException {
16094 client ().performRequest (createIndex );
16195 Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
16296 index .addParameter ("refresh" , "true" );
163- index .setJsonEntity (BULK1 );
97+ index .setJsonEntity (generateBulkData ( upgradedNodes , 8 ) );
16498 client ().performRequest (index );
16599 }
166100
@@ -190,7 +124,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException {
190124 assertOK (client ().performRequest (updateMapping ));
191125 Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
192126 index .addParameter ("refresh" , "true" );
193- index .setJsonEntity (BULK2 );
127+ index .setJsonEntity (generateBulkData ( upgradedNodes , 8 ) );
194128 assertOK (client ().performRequest (index ));
195129 expectedCount = 20 ;
196130 assertCount (indexName , expectedCount );
@@ -209,9 +143,9 @@ private record Index(String type, Set<String> elementTypes) {}
209143 new Index ("int4_hnsw" , Set .of ("float" , "bfloat16" )),
210144 new Index ("flat" , ALL_ELEMENT_TYPES ),
211145 new Index ("int8_flat" , Set .of ("float" , "bfloat16" )),
212- new Index ("int4_flat" , Set .of ("float" , "bfloat16" ))
213- // new Index("bbq_hnsw", Set.of("float", "bfloat16")),
214- // new Index("bbq_flat", Set.of("float", "bfloat16"))
146+ new Index ("int4_flat" , Set .of ("float" , "bfloat16" )),
147+ new Index ("bbq_hnsw" , Set .of ("float" , "bfloat16" )),
148+ new Index ("bbq_flat" , Set .of ("float" , "bfloat16" ))
215149 );
216150
217151 public void testDenseVectorIndexOverUpgrade () throws IOException {
@@ -220,10 +154,10 @@ public void testDenseVectorIndexOverUpgrade() throws IOException {
220154
221155 for (Index i : INDEXES ) {
222156 for (String elementType : i .elementTypes ()) {
223- if (clusterSupportsIndex (i .type (), elementType ) == false ) {
157+ var dims = getDimensions (i .type (), elementType );
158+ if (dims .isEmpty ()) {
224159 continue ;
225160 }
226-
227161 String indexName = "test_index_" + i .type () + "_" + elementType ;
228162 Request createIndex = new Request ("PUT" , "/" + indexName );
229163
@@ -237,43 +171,46 @@ public void testDenseVectorIndexOverUpgrade() throws IOException {
237171 .field ("type" , "dense_vector" )
238172 .field ("element_type" , elementType )
239173 .field ("index" , "true" )
240- .field ("dims" , 8 )
174+ .field ("dims" , elementType . equals ( "bit" ) ? dims . getAsInt () * 8 : dims . getAsInt () )
241175 .field ("similarity" , "l2_norm" );
242176 if (i .type () != null ) {
243177 payload .startObject ("index_options" ).field ("type" , i .type ()).endObject ();
244178 }
245179 payload .endObject ().endObject ().endObject ().endObject ();
246180 createIndex .setJsonEntity (Strings .toString (payload ));
247181 client ().performRequest (createIndex );
248- Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
249- index .addParameter ("refresh" , "true" );
250- index .setJsonEntity (elementType .equals ("bit" ) ? BULK1_BIT : BULK1 );
251- client ().performRequest (index );
252-
253- assertCount (indexName , 10 );
254182 }
255183 }
256184 }
257185
258- if ( isUpgradedCluster () ) {
259- for (Index i : INDEXES ) {
260- for ( String elementType : i . elementTypes ()) {
261- if (clusterSupportsIndex ( i . type (), elementType ) == false ) {
262- continue ;
263- }
264- String indexName = "test_index_" + i .type () + "_" + elementType ;
186+ for ( Index i : INDEXES ) {
187+ for (String elementType : i . elementTypes () ) {
188+ var dims = getDimensions ( i . type (), elementType );
189+ if (dims . isEmpty () ) {
190+ continue ;
191+ }
192+ String indexName = "test_index_" + i .type () + "_" + elementType ;
265193
266- Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
267- index .addParameter ("refresh" , "true" );
268- index .setJsonEntity (elementType . equals ( "bit" ) ? BULK2_BIT : BULK2 );
269- assertOK (client ().performRequest (index ));
194+ Request index = new Request ("POST" , "/" + indexName + "/_bulk/" );
195+ index .addParameter ("refresh" , "true" );
196+ index .setJsonEntity (generateBulkData ( upgradedNodes , dims . getAsInt ()) );
197+ assertOK (client ().performRequest (index ));
270198
271- assertCount (indexName , 20 );
272- }
199+ assertCount (indexName , (upgradedNodes + 1 ) * 10 );
273200 }
274201 }
275202 }
276203
204+ private OptionalInt getDimensions (String type , String elementType ) {
205+ if (elementType .equals ("bfloat16" ) && oldClusterHasFeature (MapperFeatures .GENERIC_VECTOR_FORMAT ) == false ) {
206+ return OptionalInt .empty ();
207+ }
208+ if (type != null && type .startsWith ("bbq_" )) {
209+ return OptionalInt .of (64 );
210+ }
211+ return OptionalInt .of (8 );
212+ }
213+
277214 private void assertCount (String index , int count ) throws IOException {
278215 Request searchTestIndexRequest = new Request ("POST" , "/" + index + "/_search" );
279216 searchTestIndexRequest .addParameter (TOTAL_HITS_AS_INT_PARAM , "true" );
@@ -284,11 +221,4 @@ private void assertCount(String index, int count) throws IOException {
284221 EntityUtils .toString (searchTestIndexResponse .getEntity (), StandardCharsets .UTF_8 )
285222 );
286223 }
287-
288- private static boolean clusterSupportsIndex (String type , String elementType ) {
289- if (elementType .equals ("bfloat16" ) && oldClusterHasFeature (MapperFeatures .GENERIC_VECTOR_FORMAT ) == false ) {
290- return false ;
291- }
292- return true ;
293- }
294224}
0 commit comments