diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java
index a202aac426c..7b2dcb8676e 100644
--- a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java
+++ b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java
@@ -33,11 +33,14 @@
  * @author Raphael Yu
  * @author Christian Tzolov
  * @author Ricken Bazolo
+ * @author Seunghwan Jung
  */
 public class TokenTextSplitter extends TextSplitter {
 
 	private static final int DEFAULT_CHUNK_SIZE = 800;
 
+	private static final int DEFAULT_CHUNK_OVERLAP = 50;
+
 	private static final int MIN_CHUNK_SIZE_CHARS = 350;
 
 	private static final int MIN_CHUNK_LENGTH_TO_EMBED = 5;
 
@@ -53,6 +56,9 @@ public class TokenTextSplitter {
 
 	// The target size of each text chunk in tokens
 	private final int chunkSize;
 
+	// The overlap size of each text chunk in tokens
+	private final int chunkOverlap;
+
 	// The minimum size of each text chunk in characters
 	private final int minChunkSizeChars;
 
@@ -65,16 +71,20 @@ public class TokenTextSplitter {
 
 	private final boolean keepSeparator;
 
 	public TokenTextSplitter() {
-		this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
+		this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS,
+				KEEP_SEPARATOR);
 	}
 
 	public TokenTextSplitter(boolean keepSeparator) {
-		this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
+		this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS,
+				keepSeparator);
 	}
 
-	public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
-			boolean keepSeparator) {
+	public TokenTextSplitter(int chunkSize, int chunkOverlap, int minChunkSizeChars, int minChunkLengthToEmbed,
+			int maxNumChunks, boolean keepSeparator) {
+		Assert.isTrue(chunkOverlap < chunkSize, "chunk overlap must be less than chunk size");
 		this.chunkSize = chunkSize;
+		this.chunkOverlap = chunkOverlap;
 		this.minChunkSizeChars = minChunkSizeChars;
 		this.minChunkLengthToEmbed = minChunkLengthToEmbed;
 		this.maxNumChunks = maxNumChunks;
@@ -87,57 +97,89 @@ public static Builder builder() {
 
 	@Override
 	protected List<String> splitText(String text) {
-		return doSplit(text, this.chunkSize);
+		return doSplit(text, this.chunkSize, this.chunkOverlap);
 	}
 
-	protected List<String> doSplit(String text, int chunkSize) {
+	protected List<String> doSplit(String text, int chunkSize, int chunkOverlap) {
 		if (text == null || text.trim().isEmpty()) {
 			return new ArrayList<>();
 		}
 
 		List<Integer> tokens = getEncodedTokens(text);
-		List<String> chunks = new ArrayList<>();
-		int num_chunks = 0;
-		while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
-			List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
-			String chunkText = decodeTokens(chunk);
-
-			// Skip the chunk if it is empty or whitespace
-			if (chunkText.trim().isEmpty()) {
-				tokens = tokens.subList(chunk.size(), tokens.size());
-				continue;
-			}
-
-			// Find the last period or punctuation mark in the chunk
-			int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
-					Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
+		// If text is smaller than chunk size, return as a single chunk
+		if (tokens.size() <= chunkSize) {
+			String processedText = this.keepSeparator ? text.trim() : text.replace(System.lineSeparator(), " ").trim();
 
-			if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
-				// Truncate the chunk text at the punctuation mark
-				chunkText = chunkText.substring(0, lastPunctuation + 1);
+			if (processedText.length() > this.minChunkLengthToEmbed) {
+				return List.of(processedText);
 			}
+			return new ArrayList<>();
+		}
 
+		List<String> chunks = new ArrayList<>();
-		String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim()
-				: chunkText.replace(System.lineSeparator(), " ").trim();
-		if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
-			chunks.add(chunkTextToAppend);
+		int position = 0;
+		int num_chunks = 0;
+		while (position < tokens.size() && num_chunks < this.maxNumChunks) {
+			int chunkEnd = Math.min(position + chunkSize, tokens.size());
+
+			// Extract tokens for this chunk
+			List<Integer> chunkTokens = tokens.subList(position, chunkEnd);
+			String chunkText = decodeTokens(chunkTokens);
+
+			// Apply sentence boundary optimization
+			String optimizedText = optimizeChunkBoundary(chunkText);
+			int optimizedTokenCount = getEncodedTokens(optimizedText).size();
+
+			// Use optimized chunk
+			String finalChunkText = optimizedText;
+			int finalChunkTokenCount = optimizedTokenCount;
+
+			// Advance position with minimum advance guarantee
+			// This prevents creating a series of mini chunks when boundary optimization
+			// aggressively shrinks chunks
+			int naturalAdvance = finalChunkTokenCount - chunkOverlap;
+			int minAdvance = Math.max(1, (chunkSize - chunkOverlap) / 2);
+			int advance = Math.max(naturalAdvance, minAdvance);
+			position += advance;
+
+			// Format according to keepSeparator setting
+			String formattedChunk = this.keepSeparator ? finalChunkText.trim()
+					: finalChunkText.replace(System.lineSeparator(), " ").trim();
+
+			// Add chunk if it meets minimum length
+			if (formattedChunk.length() > this.minChunkLengthToEmbed) {
+				chunks.add(formattedChunk);
+				num_chunks++;
 			}
+		}
 
-		// Remove the tokens corresponding to the chunk text from the remaining tokens
-		tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size());
+		return chunks;
+	}
 
-		num_chunks++;
+	private String optimizeChunkBoundary(String chunkText) {
+		if (chunkText.length() <= this.minChunkSizeChars) {
+			return chunkText;
 		}
 
-		// Handle the remaining tokens
-		if (!tokens.isEmpty()) {
-			String remaining_text = decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
-			if (remaining_text.length() > this.minChunkLengthToEmbed) {
-				chunks.add(remaining_text);
+		// Look for sentence endings: . ! ? \n
+		int bestCutPoint = -1;
+
+		// Check in reverse order to find the last sentence ending
+		for (int i = chunkText.length() - 1; i >= this.minChunkSizeChars; i--) {
+			char c = chunkText.charAt(i);
+			if (c == '.' || c == '!' || c == '?' || c == '\n') {
+				bestCutPoint = i + 1; // Include the punctuation
+				break;
			}
 		}
 
-		return chunks;
+		// If we found a good cut point, use it
+		if (bestCutPoint > 0) {
+			return chunkText.substring(0, bestCutPoint);
		}
+
+		// Otherwise return the original chunk
+		return chunkText;
 	}
 
 	private List<Integer> getEncodedTokens(String text) {
@@ -156,6 +198,8 @@ public static final class Builder {
 
 		private int chunkSize = DEFAULT_CHUNK_SIZE;
 
+		private int chunkOverlap = DEFAULT_CHUNK_OVERLAP;
+
 		private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS;
 
 		private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED;
 
@@ -172,6 +216,11 @@ public Builder withChunkSize(int chunkSize) {
 			return this;
 		}
 
+		public Builder withChunkOverlap(int chunkOverlap) {
+			this.chunkOverlap = chunkOverlap;
+			return this;
+		}
+
 		public Builder withMinChunkSizeChars(int minChunkSizeChars) {
 			this.minChunkSizeChars = minChunkSizeChars;
 			return this;
@@ -193,8 +242,8 @@ public Builder withKeepSeparator(boolean keepSeparator) {
 		}
 
 		public TokenTextSplitter build() {
-			return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
-					this.maxNumChunks, this.keepSeparator);
+			return new TokenTextSplitter(this.chunkSize, this.chunkOverlap, this.minChunkSizeChars,
+					this.minChunkLengthToEmbed, this.maxNumChunks, this.keepSeparator);
 		}
 
 	}
diff --git a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java
index 96c58f3fa9a..2dca2e39c1f 100644
--- a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java
+++ b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java
@@ -25,9 +25,11 @@
 import org.springframework.ai.document.Document;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * @author Ricken Bazolo
+ * @author Seunghwan Jung
  */
 public class TokenTextSplitterTest {
 
@@ -43,9 +45,9 @@ public void testTokenTextSplitterBuilderWithDefaultValues() {
 				Map.of("key1", "value1", "key2", "value2"));
 		doc1.setContentFormatter(contentFormatter1);
 
-		var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly "
-				+ "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.",
-				Map.of("key2", "value22", "key3", "value3"));
+		String doc2Text = "The most oppressive thing about the labyrinth is that you are constantly "
+				+ "being forced to choose. It isn't the lack of an exit, but the abundance of exits that is so disorienting.";
+		var doc2 = new Document(doc2Text, Map.of("key2", "value22", "key3", "value3"));
 		doc2.setContentFormatter(contentFormatter2);
 
 		var tokenTextSplitter = new TokenTextSplitter();
@@ -54,12 +56,9 @@ public void testTokenTextSplitterBuilderWithDefaultValues() {
 
 		assertThat(chunks.size()).isEqualTo(2);
 
-		// Doc 1
 		assertThat(chunks.get(0).getText())
 			.isEqualTo("In the end, writing arises when man realizes that memory is not enough.");
 
-		// Doc 2
-		assertThat(chunks.get(1).getText()).isEqualTo(
-				"The most oppressive thing about the labyrinth is that you are constantly being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.");
+		assertThat(chunks.get(1).getText()).isEqualTo(doc2Text);
 
 		assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
 		assertThat(chunks.get(1).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
@@ -83,46 +82,246 @@ public void testTokenTextSplitterBuilderWithAllFields() {
 		doc2.setContentFormatter(contentFormatter2);
 
 		var tokenTextSplitter = TokenTextSplitter.builder()
-			.withChunkSize(10)
-			.withMinChunkSizeChars(5)
-			.withMinChunkLengthToEmbed(3)
+			.withChunkSize(20)
+			.withChunkOverlap(3)
+			.withMinChunkSizeChars(10)
+			.withMinChunkLengthToEmbed(5)
 			.withMaxNumChunks(50)
 			.withKeepSeparator(true)
 			.build();
 
 		var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));
 
-		assertThat(chunks.size()).isEqualTo(6);
-
-		// Doc 1
-		assertThat(chunks.get(0).getText()).isEqualTo("In the end, writing arises when man realizes that");
-		assertThat(chunks.get(1).getText()).isEqualTo("memory is not enough.");
-
-		// Doc 2
-		assertThat(chunks.get(2).getText()).isEqualTo("The most oppressive thing about the labyrinth is that you");
-		assertThat(chunks.get(3).getText()).isEqualTo("are constantly being forced to choose.");
-		assertThat(chunks.get(4).getText()).isEqualTo("It isn’t the lack of an exit, but");
-		assertThat(chunks.get(5).getText()).isEqualTo("the abundance of exits that is so disorienting");
-
-		// Verify that the original metadata is copied to all chunks (including
-		// chunk-specific fields)
-		assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2", "parent_document_id", "chunk_index",
-				"total_chunks");
-		assertThat(chunks.get(1).getMetadata()).containsKeys("key1", "key2", "parent_document_id", "chunk_index",
-				"total_chunks");
-		assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3", "parent_document_id", "chunk_index",
-				"total_chunks");
-		assertThat(chunks.get(3).getMetadata()).containsKeys("key2", "key3", "parent_document_id", "chunk_index",
-				"total_chunks");
-
-		// Verify chunk indices are correct
-		assertThat(chunks.get(0).getMetadata().get("chunk_index")).isEqualTo(0);
-		assertThat(chunks.get(1).getMetadata().get("chunk_index")).isEqualTo(1);
-		assertThat(chunks.get(2).getMetadata().get("chunk_index")).isEqualTo(0);
-		assertThat(chunks.get(3).getMetadata().get("chunk_index")).isEqualTo(1);
+		assertThat(chunks.size()).isBetween(4, 10);
 
-		assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
-		assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
+		for (Document chunk : chunks) {
+			assertThat(chunk.getText()).isNotEmpty();
+			assertThat(chunk.getText().trim().length()).isGreaterThanOrEqualTo(5);
+		}
+
+		boolean foundDoc1Chunks = false;
+		boolean foundDoc2Chunks = false;
+
+		for (Document chunk : chunks) {
+			Map<String, Object> metadata = chunk.getMetadata();
+
+			if (metadata.containsKey("key1") && !metadata.containsKey("key3")) {
+				assertThat(metadata).containsKeys("key1", "key2").doesNotContainKeys("key3");
+				foundDoc1Chunks = true;
+			}
+			else if (metadata.containsKey("key3") && !metadata.containsKey("key1")) {
+				assertThat(metadata).containsKeys("key2", "key3").doesNotContainKeys("key1");
+				foundDoc2Chunks = true;
+			}
+		}
+
+		assertThat(foundDoc1Chunks).isTrue();
+		assertThat(foundDoc2Chunks).isTrue();
+
+		for (Document chunk : chunks) {
+			assertThat(chunk.getMetadata()).containsKeys("parent_document_id", "chunk_index", "total_chunks");
+		}
+
+		int doc1ChunkIndex = 0;
+		int doc2ChunkIndex = 0;
+		for (Document chunk : chunks) {
+			Map<String, Object> metadata = chunk.getMetadata();
+
+			if (metadata.containsKey("key1") && !metadata.containsKey("key3")) {
+				assertThat(metadata.get("chunk_index")).isEqualTo(doc1ChunkIndex);
+				doc1ChunkIndex++;
+			}
+			else if (metadata.containsKey("key3") && !metadata.containsKey("key1")) {
+				assertThat(metadata.get("chunk_index")).isEqualTo(doc2ChunkIndex);
+				doc2ChunkIndex++;
+			}
+		}
+	}
+
+	@Test
+	public void testChunkOverlapFunctionality() {
+		String longText = "This is the first sentence. This is the second sentence. "
+				+ "This is the third sentence. This is the fourth sentence. "
+				+ "This is the fifth sentence. This is the sixth sentence.";
+
+		var doc = new Document(longText);
+
+		var tokenTextSplitter = TokenTextSplitter.builder()
+			.withChunkSize(15)
+			.withChunkOverlap(5)
+			.withMinChunkSizeChars(10)
+			.withMinChunkLengthToEmbed(5)
+			.withKeepSeparator(false)
+			.build();
+
+		var chunks = tokenTextSplitter.apply(List.of(doc));
+
+		assertThat(chunks.size()).isGreaterThan(1);
+
+		if (chunks.size() >= 2) {
+			String firstChunk = chunks.get(0).getText();
+			String secondChunk = chunks.get(1).getText();
+
+			assertThat(firstChunk).isNotEmpty();
+			assertThat(secondChunk).isNotEmpty();
+		}
+	}
+
+	@Test
+	public void testChunkOverlapValidation() {
+		assertThatThrownBy(() -> TokenTextSplitter.builder().withChunkSize(10).withChunkOverlap(15).build())
+			.isInstanceOf(IllegalArgumentException.class)
+			.hasMessageContaining("chunk overlap must be less than chunk size");
+
+		assertThatThrownBy(() -> TokenTextSplitter.builder().withChunkSize(10).withChunkOverlap(10).build())
+			.isInstanceOf(IllegalArgumentException.class)
+			.hasMessageContaining("chunk overlap must be less than chunk size");
+	}
+
+	@Test
+	public void testBoundaryOptimizationWithOverlap() {
+		String text = "First sentence here. Second sentence follows immediately. "
+				+ "Third sentence is next. Fourth sentence continues the text. "
+				+ "Fifth sentence completes this test.";
+
+		var doc = new Document(text);
+
+		var tokenTextSplitter = TokenTextSplitter.builder()
+			.withChunkSize(20)
+			.withChunkOverlap(3)
+			.withMinChunkSizeChars(20)
+			.withMinChunkLengthToEmbed(5)
+			.withKeepSeparator(true)
+			.build();
+
+		var chunks = tokenTextSplitter.apply(List.of(doc));
+
+		assertThat(chunks).isNotEmpty();
+
+		for (Document chunk : chunks) {
+			String chunkText = chunk.getText();
+			if (chunkText != null && chunkText.trim().length() > 20) {
+				assertThat(chunkText.trim()).isNotEmpty();
+			}
+		}
+	}
+
+	@Test
+	public void testKeepSeparatorVariations() {
+		String textWithNewlines = "Line one content here.\nLine two content here.\nLine three content here.";
+		var doc = new Document(textWithNewlines);
+
+		var splitterKeepSeparator = TokenTextSplitter.builder()
+			.withChunkSize(50)
+			.withChunkOverlap(0)
+			.withKeepSeparator(true)
+			.build();
+
+		var chunksWithSeparator = splitterKeepSeparator.apply(List.of(doc));
+
+		var splitterNoSeparator = TokenTextSplitter.builder()
+			.withChunkSize(50)
+			.withChunkOverlap(0)
+			.withKeepSeparator(false)
+			.build();
+
+		var chunksWithoutSeparator = splitterNoSeparator.apply(List.of(doc));
+
+		assertThat(chunksWithSeparator).isNotEmpty();
+		assertThat(chunksWithoutSeparator).isNotEmpty();
+
+		if (chunksWithSeparator.size() == 1 && chunksWithoutSeparator.size() == 1) {
+			String withSeparatorText = chunksWithSeparator.get(0).getText();
+			String withoutSeparatorText = chunksWithoutSeparator.get(0).getText();
+
+			assertThat(withSeparatorText).contains("\n");
+			assertThat(withoutSeparatorText).doesNotContain("\n");
+		}
+	}
+
+	@Test
+	public void testNoMiniChunksAtEnd() {
+		StringBuilder longText = new StringBuilder();
+		for (int i = 0; i < 100; i++) {
+			longText.append("This is sentence number ")
+				.append(i)
+				.append(" and it contains some meaningful content to test the chunking behavior. ");
+		}
+
+		var doc = new Document(longText.toString());
+
+		var tokenTextSplitter = TokenTextSplitter.builder()
+			.withChunkSize(100)
+			.withChunkOverlap(10)
+			.withMinChunkSizeChars(50)
+			.withMinChunkLengthToEmbed(5)
+			.withKeepSeparator(true)
+			.build();
+
+		var chunks = tokenTextSplitter.apply(List.of(doc));
+
+		assertThat(chunks.size()).isGreaterThan(1);
+
+		var encoding = com.knuddels.jtokkit.Encodings.newDefaultEncodingRegistry()
+			.getEncoding(com.knuddels.jtokkit.api.EncodingType.CL100K_BASE);
+
+		int minExpectedAdvance = (100 - 10) / 2;
+		int consecutiveSmallChunks = 0;
+		int maxConsecutiveSmallChunks = 0;
+
+		for (int i = 0; i < chunks.size() - 1; i++) {
+			String chunkText = chunks.get(i).getText();
+			int tokenCount = encoding.encode(chunkText).size();
+
+			if (tokenCount < minExpectedAdvance) {
+				consecutiveSmallChunks++;
+				maxConsecutiveSmallChunks = Math.max(maxConsecutiveSmallChunks, consecutiveSmallChunks);
+			}
+			else {
+				consecutiveSmallChunks = 0;
+			}
+
+			assertThat(tokenCount)
+				.as("Chunk %d should have at least %d tokens but has %d", i, minExpectedAdvance, tokenCount)
+				.isGreaterThanOrEqualTo(minExpectedAdvance);
+		}
+
+		assertThat(maxConsecutiveSmallChunks)
+			.as("Should not have multiple consecutive small chunks (found %d consecutive)", maxConsecutiveSmallChunks)
+			.isLessThanOrEqualTo(1);
+	}
+
+	@Test
+	public void testChunkSizesAreConsistent() {
+		StringBuilder text = new StringBuilder();
+		for (int i = 0; i < 50; i++) {
+			text.append("Sentence ").append(i).append(" contains important information for testing. ");
+		}
+
+		var doc = new Document(text.toString());
+
+		var tokenTextSplitter = TokenTextSplitter.builder()
+			.withChunkSize(80)
+			.withChunkOverlap(10)
+			.withMinChunkSizeChars(100)
+			.withMinChunkLengthToEmbed(5)
+			.withKeepSeparator(false)
+			.build();
+
+		var chunks = tokenTextSplitter.apply(List.of(doc));
+
+		assertThat(chunks.size()).isGreaterThan(1);
+
+		var encoding = com.knuddels.jtokkit.Encodings.newDefaultEncodingRegistry()
+			.getEncoding(com.knuddels.jtokkit.api.EncodingType.CL100K_BASE);
+
+		for (int i = 0; i < chunks.size() - 1; i++) {
+			int tokenCount = encoding.encode(chunks.get(i).getText()).size();
+			assertThat(tokenCount).as("Chunk %d token count should be reasonable", i).isBetween(40, 120);
+		}
+
+		int lastChunkTokens = encoding.encode(chunks.get(chunks.size() - 1).getText()).size();
+		assertThat(lastChunkTokens).isGreaterThan(0);
 	}
 
 }