Skip to content

Commit 2743e53

Browse files
Add late chunking configuration for JinaAI embedding task settings (#137263)
* Add late chunking configuration for JinaAI embedding task settings
* Update docs/changelog/137263.yaml
* Clean up tests and fix mutateInstance for JinaAIEmbeddingsTaskSettingsTests
* Clean up EmbeddingRequestChunker tests and disable late chunking for inputs exceeding the max word count
* Fix test sentence generation
* Add a test for generating multiple batches and clarify the late chunking word count limit
1 parent 3ed4361 commit 2743e53

File tree

13 files changed

+478
-92
lines changed

13 files changed

+478
-92
lines changed

docs/changelog/137263.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137263
2+
summary: Add late chunking configuration for JinaAI embedding task settings
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9222000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
index_created_transport_version,9221000
1+
jina_ai_configurable_late_chunking,9222000

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkerUtils.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,16 @@
99

1010
import com.ibm.icu.text.BreakIterator;
1111

12+
import java.util.Locale;
13+
1214
public class ChunkerUtils {
1315

16+
public static int countWords(String text) {
17+
BreakIterator wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
18+
wordIterator.setText(text);
19+
return countWords(0, text.length(), wordIterator);
20+
}
21+
1422
// setText() should be applied before using this function.
1523
static int countWords(int start, int end, BreakIterator wordIterator) {
1624
assert start < end;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,26 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
8181
private ActionListener<List<ChunkedInference>> finalListener;
8282

8383
public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch) {
84-
this(inputs, maxNumberOfInputsPerBatch, null);
84+
this(inputs, maxNumberOfInputsPerBatch, true, null);
8585
}
8686

8787
public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) {
88-
this(inputs, maxNumberOfInputsPerBatch, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap));
88+
this(inputs, maxNumberOfInputsPerBatch, true, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap));
8989
}
9090

9191
public EmbeddingRequestChunker(
9292
List<ChunkInferenceInput> inputs,
9393
int maxNumberOfInputsPerBatch,
9494
ChunkingSettings defaultChunkingSettings
95+
) {
96+
this(inputs, maxNumberOfInputsPerBatch, true, defaultChunkingSettings);
97+
}
98+
99+
public EmbeddingRequestChunker(
100+
List<ChunkInferenceInput> inputs,
101+
int maxNumberOfInputsPerBatch,
102+
boolean batchChunksAcrossInputs,
103+
ChunkingSettings defaultChunkingSettings
95104
) {
96105
this.resultEmbeddings = new ArrayList<>(inputs.size());
97106
this.resultOffsetStarts = new ArrayList<>(inputs.size());
@@ -147,13 +156,23 @@ public EmbeddingRequestChunker(
147156
}
148157
}
149158

150-
AtomicInteger counter = new AtomicInteger();
151-
this.batchRequests = allRequests.stream()
152-
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
153-
.values()
154-
.stream()
155-
.map(BatchRequest::new)
156-
.toList();
159+
if (batchChunksAcrossInputs) {
160+
AtomicInteger counter = new AtomicInteger();
161+
this.batchRequests = allRequests.stream()
162+
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
163+
.values()
164+
.stream()
165+
.map(BatchRequest::new)
166+
.toList();
167+
} else {
168+
assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS);
169+
this.batchRequests = allRequests.stream()
170+
.collect(Collectors.groupingBy(Request::inputIndex))
171+
.values()
172+
.stream()
173+
.map(BatchRequest::new)
174+
.toList();
175+
}
157176
}
158177

159178
/**

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkerUtilsTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import org.elasticsearch.test.ESTestCase;
1313

1414
import java.util.Locale;
15+
import java.util.stream.Collectors;
16+
import java.util.stream.IntStream;
1517

1618
import static org.elasticsearch.xpack.core.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
1719

@@ -85,6 +87,12 @@ public void testCountWords_WithSymbols() {
8587
}
8688
}
8789

90+
public void testCountWords_GivenStringCountsAllWords() {
91+
int wordCount = randomIntBetween(1, 100);
92+
var testText = IntStream.range(0, wordCount).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + ".";
93+
assertEquals(wordCount, ChunkerUtils.countWords(testText));
94+
}
95+
8896
private int[] sentenceSizes(String text) {
8997
var sentences = text.split("\\.\\s+");
9098
var lengths = new int[sentences.length];

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
2222

2323
import java.util.ArrayList;
24+
import java.util.Collections;
2425
import java.util.List;
2526
import java.util.concurrent.atomic.AtomicReference;
27+
import java.util.stream.Collectors;
28+
import java.util.stream.IntStream;
2629

2730
import static org.elasticsearch.inference.InferenceString.DataType.TEXT;
2831
import static org.elasticsearch.inference.InferenceString.toStringList;
@@ -38,6 +41,8 @@
3841

3942
public class EmbeddingRequestChunkerTests extends ESTestCase {
4043

44+
private static final int MAX_BATCH_SIZE = 512;
45+
4146
public void testEmptyInput_WordChunker() {
4247
var batches = new EmbeddingRequestChunker<>(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener());
4348
assertThat(batches, empty());
@@ -943,6 +948,78 @@ public void testMergingListener_Sparse() {
943948
}
944949
}
945950

951+
public void testBatchChunksAcrossInputsIsFalse_DoesNotBatchChunksFromSeparateInputs() {
952+
testBatchChunksAcrossInputs(false, List.of(3, 1, 4));
953+
}
954+
955+
public void testBatchChunksAcrossInputsIsTrue_DoesBatchChunksFromSeparateInputs() {
956+
testBatchChunksAcrossInputs(true, List.of(3, 1, 4));
957+
}
958+
959+
public void testBatchChunksAcrossInputsIsTrue_GeneratesMultipleBatches() {
960+
testBatchChunksAcrossInputs(true, List.of(200, 200, 200));
961+
}
962+
963+
public void testBatchChunksAcrossInputsIsFalseAndBatchesLessThanMaxChunkLimit_ThrowsAssertionError() {
964+
int batchSize = randomIntBetween(1, MAX_BATCH_SIZE - 1);
965+
List<ChunkInferenceInput> inputs = List.of(new ChunkInferenceInput("This is a test sentence with ten words in total. "));
966+
var chunkingSettings = new SentenceBoundaryChunkingSettings(10, 0);
967+
expectThrows(
968+
AssertionError.class,
969+
() -> new EmbeddingRequestChunker<>(inputs, batchSize, false, chunkingSettings).batchRequestsWithListeners(testListener())
970+
);
971+
}
972+
973+
private void testBatchChunksAcrossInputs(boolean batchChunksAcrossInputs, List<Integer> batchSizes) {
974+
int maxChunkSize = 10;
975+
var testSentence = IntStream.range(0, maxChunkSize).mapToObj(i -> "Word" + i).collect(Collectors.joining(" ")) + ".";
976+
var chunkingSettings = new SentenceBoundaryChunkingSettings(maxChunkSize, 0);
977+
var totalBatchSizes = batchSizes.stream().mapToInt(Integer::intValue).sum();
978+
List<ChunkInferenceInput> inputs = batchSizes.stream()
979+
.map(i -> new ChunkInferenceInput(String.join(" ", Collections.nCopies(i, testSentence))))
980+
.toList();
981+
982+
var finalListener = testListener();
983+
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(
984+
inputs,
985+
MAX_BATCH_SIZE,
986+
batchChunksAcrossInputs,
987+
chunkingSettings
988+
).batchRequestsWithListeners(finalListener);
989+
990+
// If we are batching chunks across inputs, we expect the batches to be filled up to the max batch size.
991+
// Otherwise, we expect one batch per input.
992+
int expectedNumberOfBatches = batchChunksAcrossInputs ? (int) Math.ceil((double) totalBatchSizes / MAX_BATCH_SIZE) : inputs.size();
993+
assertThat(batches, hasSize(expectedNumberOfBatches));
994+
if (batchChunksAcrossInputs) {
995+
for (int i = 0; i < batches.size(); i++) {
996+
var expectedBatchSize = i < batches.size() - 1 ? MAX_BATCH_SIZE : totalBatchSizes - (MAX_BATCH_SIZE * (batches.size() - 1));
997+
assertThat(batches.get(i).batch().inputs().get(), hasSize(expectedBatchSize));
998+
batches.get(i)
999+
.listener()
1000+
.onResponse(
1001+
new DenseEmbeddingFloatResults(
1002+
List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloatBetween(0, 1, true) }))
1003+
)
1004+
);
1005+
}
1006+
} else {
1007+
for (int i = 0; i < batches.size(); i++) {
1008+
assertThat(batches.get(i).batch().inputs().get(), hasSize(batchSizes.get(i)));
1009+
batches.get(i)
1010+
.listener()
1011+
.onResponse(
1012+
new DenseEmbeddingFloatResults(
1013+
List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloatBetween(0, 1, true) }))
1014+
)
1015+
);
1016+
}
1017+
}
1018+
1019+
assertNotNull(finalListener.results);
1020+
assertThat(finalListener.results, hasSize(3));
1021+
}
1022+
9461023
public void testListenerErrorsWithWrongNumberOfResponses() {
9471024
List<ChunkInferenceInput> inputs = List.of(
9481025
new ChunkInferenceInput("1st small"),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
4545
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
4646
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
47+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
4748
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
4849
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4950
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -276,9 +277,16 @@ protected void doChunkedInfer(
276277
JinaAIModel jinaaiModel = (JinaAIModel) model;
277278
var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());
278279

280+
boolean batchChunksAcrossInputs = true;
281+
if (jinaaiModel.getTaskSettings() instanceof JinaAIEmbeddingsTaskSettings jinaAIEmbeddingsTaskSettings) {
282+
batchChunksAcrossInputs = jinaAIEmbeddingsTaskSettings.getLateChunking() == null
283+
|| jinaAIEmbeddingsTaskSettings.getLateChunking() == false;
284+
}
285+
279286
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
280287
inputs,
281288
EMBEDDING_MAX_BATCH_SIZE,
289+
batchChunksAcrossInputs,
282290
jinaaiModel.getConfigurations().getChunkingSettings()
283291
).batchRequestsWithListeners(listener);
284292

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.Objects;
2525

2626
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
2728
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
2829
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES;
2930

@@ -36,6 +37,11 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings {
3637
public static final String NAME = "jinaai_embeddings_task_settings";
3738
public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null);
3839
static final String INPUT_TYPE = "input_type";
40+
static final String LATE_CHUNKING = "late_chunking";
41+
42+
protected static final TransportVersion JINA_AI_CONFIGURABLE_LATE_CHUNKING = TransportVersion.fromName(
43+
"jina_ai_configurable_late_chunking"
44+
);
3945

4046
public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
4147
if (map == null || map.isEmpty()) {
@@ -53,11 +59,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
5359
validationException
5460
);
5561

62+
Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException);
63+
5664
if (validationException.validationErrors().isEmpty() == false) {
5765
throw validationException;
5866
}
5967

60-
return new JinaAIEmbeddingsTaskSettings(inputType);
68+
return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking);
6169
}
6270

6371
/**
@@ -76,8 +84,9 @@ public static JinaAIEmbeddingsTaskSettings of(
7684
JinaAIEmbeddingsTaskSettings requestTaskSettings
7785
) {
7886
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings);
87+
var lateChunkingToUse = requestTaskSettings.lateChunking != null ? requestTaskSettings.lateChunking : originalSettings.lateChunking;
7988

80-
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse);
89+
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, lateChunkingToUse);
8190
}
8291

8392
private static InputType getValidInputType(
@@ -94,14 +103,28 @@ private static InputType getValidInputType(
94103
}
95104

96105
private final InputType inputType;
106+
private final Boolean lateChunking;
97107

98108
public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
99-
this(in.readOptionalEnum(InputType.class));
109+
this.inputType = in.readOptionalEnum(InputType.class);
110+
111+
if (in.getTransportVersion().supports(JINA_AI_CONFIGURABLE_LATE_CHUNKING)) {
112+
this.lateChunking = in.readOptionalBoolean();
113+
} else {
114+
this.lateChunking = null;
115+
}
116+
}
117+
118+
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, Boolean lateChunking) {
119+
validateInputType(inputType);
120+
this.inputType = inputType;
121+
this.lateChunking = lateChunking;
100122
}
101123

102124
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) {
103125
validateInputType(inputType);
104126
this.inputType = inputType;
127+
this.lateChunking = null;
105128
}
106129

107130
private static void validateInputType(InputType inputType) {
@@ -114,7 +137,7 @@ private static void validateInputType(InputType inputType) {
114137

115138
@Override
116139
public boolean isEmpty() {
117-
return inputType == null;
140+
return inputType == null && lateChunking == null;
118141
}
119142

120143
@Override
@@ -124,6 +147,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
124147
builder.field(INPUT_TYPE, inputType);
125148
}
126149

150+
if (lateChunking != null) {
151+
builder.field(LATE_CHUNKING, lateChunking);
152+
}
153+
127154
builder.endObject();
128155
return builder;
129156
}
@@ -132,6 +159,10 @@ public InputType getInputType() {
132159
return inputType;
133160
}
134161

162+
public Boolean getLateChunking() {
163+
return lateChunking;
164+
}
165+
135166
@Override
136167
public String getWriteableName() {
137168
return NAME;
@@ -145,19 +176,23 @@ public TransportVersion getMinimalSupportedVersion() {
145176
@Override
146177
public void writeTo(StreamOutput out) throws IOException {
147178
out.writeOptionalEnum(inputType);
179+
180+
if (out.getTransportVersion().supports(JINA_AI_CONFIGURABLE_LATE_CHUNKING)) {
181+
out.writeOptionalBoolean(lateChunking);
182+
}
148183
}
149184

150185
@Override
151186
public boolean equals(Object o) {
152187
if (this == o) return true;
153188
if (o == null || getClass() != o.getClass()) return false;
154189
JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o;
155-
return Objects.equals(inputType, that.inputType);
190+
return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking);
156191
}
157192

158193
@Override
159194
public int hashCode() {
160-
return Objects.hash(inputType);
195+
return Objects.hash(inputType, lateChunking);
161196
}
162197

163198
@Override

0 commit comments

Comments (0)