Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,49 @@
/**
* @author Raphael Yu
* @author Christian Tzolov
* @author Ricken Bazolo
*/
public class TokenTextSplitter extends TextSplitter {

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();

private final Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE);

private final static int DEFAULT_CHUNK_SIZE = 800;

private final static int MIN_CHUNK_SIZE_CHARS = 350;

private final static int MIN_CHUNK_LENGTH_TO_EMBED = 5;

private final static int MAX_NUM_CHUNKS = 10000;

private final static boolean KEEP_SEPARATOR = true;

// The target size of each text chunk in tokens
private int defaultChunkSize = 800;
private final int chunkSize;

// The minimum size of each text chunk in characters
private int minChunkSizeChars = 350;
private final int minChunkSizeChars;

// Discard chunks shorter than this
private int minChunkLengthToEmbed = 5;
private final int minChunkLengthToEmbed;

// The maximum number of chunks to generate from a text
private int maxNumChunks = 10000;
private final int maxNumChunks;

private boolean keepSeparator = true;
private final boolean keepSeparator;

public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
}

public TokenTextSplitter(boolean keepSeparator) {
this.keepSeparator = keepSeparator;
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
}

public TokenTextSplitter(int defaultChunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
this.defaultChunkSize = defaultChunkSize;
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
Expand All @@ -68,7 +80,7 @@ public TokenTextSplitter(int defaultChunkSize, int minChunkSizeChars, int minChu

@Override
protected List<String> splitText(String text) {
return doSplit(text, this.defaultChunkSize);
return doSplit(text, this.chunkSize);
}

protected List<String> doSplit(String text, int chunkSize) {
Expand Down Expand Up @@ -133,4 +145,55 @@ private String decodeTokens(List<Integer> tokens) {
return this.encoding.decode(tokensIntArray);
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private int chunkSize;

private int minChunkSizeChars;

private int minChunkLengthToEmbed;

private int maxNumChunks;

private boolean keepSeparator;

public Builder() {
}

public Builder chunkSize(int chunkSize) {
this.chunkSize = chunkSize;
return this;
}

public Builder minChunkSizeChars(int minChunkSizeChars) {
this.minChunkSizeChars = minChunkSizeChars;
return this;
}

public Builder minChunkLengthToEmbed(int minChunkLengthToEmbed) {
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
return this;
}

public Builder maxNumChunks(int maxNumChunks) {
this.maxNumChunks = maxNumChunks;
return this;
}

public Builder keepSeparator(boolean keepSeparator) {
this.keepSeparator = keepSeparator;
return this;
}

public TokenTextSplitter build() {
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
this.maxNumChunks, this.keepSeparator);
}

}

}