diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index cc034b49cbe..420e7a28772 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -29,6 +29,7 @@ /** * @author Raphael Yu * @author Christian Tzolov + * @author Ricken Bazolo */ public class TokenTextSplitter extends TextSplitter { @@ -36,30 +37,41 @@ public class TokenTextSplitter extends TextSplitter { private final Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE); + private final static int DEFAULT_CHUNK_SIZE = 800; + + private final static int MIN_CHUNK_SIZE_CHARS = 350; + + private final static int MIN_CHUNK_LENGTH_TO_EMBED = 5; + + private final static int MAX_NUM_CHUNKS = 10000; + + private final static boolean KEEP_SEPARATOR = true; + // The target size of each text chunk in tokens - private int defaultChunkSize = 800; + private final int chunkSize; // The minimum size of each text chunk in characters - private int minChunkSizeChars = 350; + private final int minChunkSizeChars; // Discard chunks shorter than this - private int minChunkLengthToEmbed = 5; + private final int minChunkLengthToEmbed; // The maximum number of chunks to generate from a text - private int maxNumChunks = 10000; + private final int maxNumChunks; - private boolean keepSeparator = true; + private final boolean keepSeparator; public TokenTextSplitter() { + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR); } public TokenTextSplitter(boolean keepSeparator) { - this.keepSeparator = keepSeparator; + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator); } - public TokenTextSplitter(int defaultChunkSize, 
int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, + public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, boolean keepSeparator) { - this.defaultChunkSize = defaultChunkSize; + this.chunkSize = chunkSize; this.minChunkSizeChars = minChunkSizeChars; this.minChunkLengthToEmbed = minChunkLengthToEmbed; this.maxNumChunks = maxNumChunks; @@ -68,7 +80,7 @@ public TokenTextSplitter(int defaultChunkSize, int minChunkSizeChars, int minChu @Override protected List splitText(String text) { - return doSplit(text, this.defaultChunkSize); + return doSplit(text, this.chunkSize); } protected List doSplit(String text, int chunkSize) { @@ -133,4 +145,55 @@ private String decodeTokens(List tokens) { return this.encoding.decode(tokensIntArray); } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private int chunkSize; + + private int minChunkSizeChars; + + private int minChunkLengthToEmbed; + + private int maxNumChunks; + + private boolean keepSeparator; + + private Builder() { + } + + public Builder withChunkSize(int chunkSize) { + this.chunkSize = chunkSize; + return this; + } + + public Builder withMinChunkSizeChars(int minChunkSizeChars) { + this.minChunkSizeChars = minChunkSizeChars; + return this; + } + + public Builder withMinChunkLengthToEmbed(int minChunkLengthToEmbed) { + this.minChunkLengthToEmbed = minChunkLengthToEmbed; + return this; + } + + public Builder withMaxNumChunks(int maxNumChunks) { + this.maxNumChunks = maxNumChunks; + return this; + } + + public Builder withKeepSeparator(boolean keepSeparator) { + this.keepSeparator = keepSeparator; + return this; + } + + public TokenTextSplitter build() { + return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed, + this.maxNumChunks, this.keepSeparator); + } + + } + } diff --git 
package org.springframework.ai.transformer.splitter;

import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.Document;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Unit tests for {@link TokenTextSplitter}.
 *
 * @author Ricken Bazolo
 */
public class TokenTextSplitterTest {

	/**
	 * Builds the two fixture documents shared by both tests. Each document gets its
	 * own (distinct) content formatter and the two documents carry overlapping
	 * metadata keys ({@code key2}) so the tests can verify that chunk metadata is
	 * copied per source document, not merged across documents.
	 */
	private static List<Document> testDocuments() {

		var contentFormatter1 = DefaultContentFormatter.defaultConfig();
		var contentFormatter2 = DefaultContentFormatter.defaultConfig();

		// defaultConfig() must return a fresh instance each time; the per-document
		// formatter identity matters for the splitter's copy semantics.
		assertThat(contentFormatter1).isNotSameAs(contentFormatter2);

		var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.",
				Map.of("key1", "value1", "key2", "value2"));
		doc1.setContentFormatter(contentFormatter1);

		var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly "
				+ "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.",
				Map.of("key2", "value22", "key3", "value3"));
		doc2.setContentFormatter(contentFormatter2);

		return List.of(doc1, doc2);
	}

	// NOTE(review): despite the "Builder" in its name, this test exercises the
	// no-arg constructor (the default configuration), not the builder.
	@Test
	public void testTokenTextSplitterBuilderWithDefaultValues() {

		var tokenTextSplitter = new TokenTextSplitter();

		var chunks = tokenTextSplitter.apply(testDocuments());

		// With the default 800-token chunk size, each short document fits in one chunk.
		assertThat(chunks.size()).isEqualTo(2);

		// Doc 1
		assertThat(chunks.get(0).getContent())
			.isEqualTo("In the end, writing arises when man realizes that memory is not enough.");
		// Doc 2
		assertThat(chunks.get(1).getContent()).isEqualTo(
				"The most oppressive thing about the labyrinth is that you are constantly being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.");

		// Each chunk carries only its own source document's metadata.
		assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
		assertThat(chunks.get(1).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
	}

	@Test
	public void testTokenTextSplitterBuilderWithAllFields() {

		var tokenTextSplitter = TokenTextSplitter.builder()
			.withChunkSize(10)
			.withMinChunkSizeChars(5)
			.withMinChunkLengthToEmbed(3)
			.withMaxNumChunks(50)
			.withKeepSeparator(true)
			.build();

		var chunks = tokenTextSplitter.apply(testDocuments());

		// A 10-token chunk size splits doc1 into 2 chunks and doc2 into 4.
		assertThat(chunks.size()).isEqualTo(6);

		// Doc 1
		assertThat(chunks.get(0).getContent()).isEqualTo("In the end, writing arises when man realizes that");
		assertThat(chunks.get(1).getContent()).isEqualTo("memory is not enough.");

		// Doc 2
		assertThat(chunks.get(2).getContent()).isEqualTo("The most oppressive thing about the labyrinth is that you");
		assertThat(chunks.get(3).getContent()).isEqualTo("are constantly being forced to choose.");
		assertThat(chunks.get(4).getContent()).isEqualTo("It isn’t the lack of an exit, but");
		assertThat(chunks.get(5).getContent()).isEqualTo("the abundance of exits that is so disorienting");

		// Verify that the same, merged metadata is copied to all chunks.
		assertThat(chunks.get(0).getMetadata()).isEqualTo(chunks.get(1).getMetadata());
		assertThat(chunks.get(2).getMetadata()).isEqualTo(chunks.get(3).getMetadata());

		assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
		assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
	}

}