From e7c688f0a2cc89c476a749b615f29516e8e58721 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Mon, 11 Nov 2024 19:09:02 -0500 Subject: [PATCH] GH-1240: Add builder pattern to ChromaVectorStore for better initialization control Fixes: #1240 Issue: https://github.com/spring-projects/spring-ai/issues/1240 This change addresses initialization issues when ChromaVectorStore is used outside a Spring context, particularly in scenarios where collections are created manually before store instantiation. Previously, the collection ID wasn't populated when afterPropertiesSet() wasn't called by the Spring container. - Add builder pattern to ChromaVectorStore for better initialization control - Add an initialization flag so that afterPropertiesSet() looks up or creates the collection only once - Add integration tests for builder pattern usage scenarios - Add spring-ai-transformers dependency for testing - Remove unused constants (SIMILARITY_THRESHOLD_ALL, DEFAULT_TOP_K) The collection ID can now be set whether the store is managed by Spring or created manually (via the builder's initializeImmediately(true) option), which resolves the 404 Not Found errors during document insertion. 
--- vector-stores/spring-ai-chroma-store/pom.xml | 7 + .../ai/vectorstore/ChromaVectorStore.java | 120 ++++++++++++++--- .../ai/chroma/ChromaApiIT.java | 124 +++++++++++++----- 3 files changed, 204 insertions(+), 47 deletions(-) diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index 3e7b35c4a8b..88f6738c4af 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -87,6 +87,13 @@ micrometer-observation-test test + + + org.springframework.ai + spring-ai-transformers + ${parent.version} + test + diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 90899516ae9..ba1166927c6 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -41,7 +41,6 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; @@ -58,7 +57,7 @@ * @author Christian Tzolov * @author Fu Cheng * @author Sebastien Deleuze - * + * @author Soby Chacko */ public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -66,10 +65,6 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements public static final 
String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; - public static final double SIMILARITY_THRESHOLD_ALL = 0.0; - - public static final int DEFAULT_TOP_K = 4; - private final EmbeddingModel embeddingModel; private final ChromaApi chromaApi; @@ -86,6 +81,8 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements private final ObjectMapper objectMapper; + private boolean initialized = false; + public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, boolean initializeSchema) { this(embeddingModel, chromaApi, DEFAULT_COLLECTION_NAME, initializeSchema); } @@ -111,6 +108,26 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); } + private ChromaVectorStore(Builder builder) { + super(builder.observationRegistry, builder.customObservationConvention); + this.embeddingModel = builder.embeddingModel; + this.chromaApi = builder.chromaApi; + this.collectionName = builder.collectionName; + this.initializeSchema = builder.initializeSchema; + this.filterExpressionConverter = builder.filterExpressionConverter; + this.batchingStrategy = builder.batchingStrategy; + this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); + + if (builder.initializeImmediately) { + try { + afterPropertiesSet(); + } + catch (Exception e) { + throw new IllegalStateException("Failed to initialize ChromaVectorStore", e); + } + } + } + public void setFilterExpressionConverter(FilterExpressionConverter filterExpressionConverter) { Assert.notNull(filterExpressionConverter, "FilterExpressionConverter should not be null."); this.filterExpressionConverter = filterExpressionConverter; @@ -207,26 +224,95 @@ public String getCollectionId() { @Override public void afterPropertiesSet() throws Exception { - var collection = this.chromaApi.getCollection(this.collectionName); - if 
(collection == null) { - if (this.initializeSchema) { - collection = this.chromaApi - .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); - } - else { - throw new RuntimeException("Collection " + this.collectionName - + " doesn't exist and won't be created as the initializeSchema is set to false."); + if (!this.initialized) { + var collection = this.chromaApi.getCollection(this.collectionName); + if (collection == null) { + if (this.initializeSchema) { + collection = this.chromaApi + .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); + } + else { + throw new RuntimeException("Collection " + this.collectionName + + " doesn't exist and won't be created as the initializeSchema is set to false."); + } } + this.collectionId = collection.id(); + this.initialized = true; } - this.collectionId = collection.id(); } @Override - public Builder createObservationContextBuilder(String operationName) { + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName) .withDimensions(this.embeddingModel.dimensions()) .withCollectionName(this.collectionName + ":" + this.collectionId) .withFieldName(this.initializeSchema ? 
DISTANCE_FIELD_NAME : null); } + public static class Builder { + + private final EmbeddingModel embeddingModel; + + private final ChromaApi chromaApi; + + private String collectionName = DEFAULT_COLLECTION_NAME; + + private boolean initializeSchema = false; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private VectorStoreObservationConvention customObservationConvention = null; + + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + + private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter(); + + private boolean initializeImmediately = false; + + public Builder(EmbeddingModel embeddingModel, ChromaApi chromaApi) { + this.embeddingModel = embeddingModel; + this.chromaApi = chromaApi; + } + + public Builder collectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + public Builder initializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public Builder customObservationConvention(VectorStoreObservationConvention convention) { + this.customObservationConvention = convention; + return this; + } + + public Builder batchingStrategy(BatchingStrategy batchingStrategy) { + this.batchingStrategy = batchingStrategy; + return this; + } + + public Builder filterExpressionConverter(FilterExpressionConverter converter) { + this.filterExpressionConverter = converter; + return this; + } + + public Builder initializeImmediately(boolean initialize) { + this.initializeImmediately = initialize; + return this; + } + + public ChromaVectorStore build() { + return new ChromaVectorStore(this); + } + + } + } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java 
b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index bfc84340294..0b01ae35486 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -16,6 +16,7 @@ package org.springframework.ai.chroma; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -30,17 +31,24 @@ import org.springframework.ai.chroma.ChromaApi.Collection; import org.springframework.ai.chroma.ChromaApi.GetEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.QueryRequest; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.ChromaVectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * @author Christian Tzolov * @author Eddú Meléndez * @author Thomas Vitale + * @author Soby Chacko */ @SpringBootTest @Testcontainers @@ -50,17 +58,20 @@ public class ChromaApiIT { static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); @Autowired - ChromaApi chroma; + ChromaApi chromaApi; + + @Autowired + EmbeddingModel embeddingModel; @BeforeEach public void beforeEach() { - this.chroma.listCollections().stream().forEach(c -> this.chroma.deleteCollection(c.name())); + this.chromaApi.listCollections().stream().forEach(c -> this.chromaApi.deleteCollection(c.name())); 
} @Test public void testClientWithMetadata() { Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5); - var newCollection = this.chroma + var newCollection = this.chromaApi .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); @@ -68,44 +79,44 @@ public void testClientWithMetadata() { @Test public void testClient() { - var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); - var getCollection = this.chroma.getCollection("TestCollection"); + var getCollection = this.chromaApi.getCollection("TestCollection"); assertThat(getCollection).isNotNull(); assertThat(getCollection.name()).isEqualTo("TestCollection"); assertThat(getCollection.id()).isEqualTo(newCollection.id()); - List collections = this.chroma.listCollections(); + List collections = this.chromaApi.listCollections(); assertThat(collections).hasSize(1); assertThat(collections.get(0).id()).isEqualTo(newCollection.id()); - this.chroma.deleteCollection(newCollection.name()); - assertThat(this.chroma.listCollections()).hasSize(0); + this.chromaApi.deleteCollection(newCollection.name()); + assertThat(this.chromaApi.listCollections()).hasSize(0); } @Test public void testCollection() { - var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); - assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(0); + var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(0); var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", 
"id2"), List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }), List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)), List.of("Hello World", "Big World")); - this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); + this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f }, Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World"); - this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); + this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); - assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(3); - var queryResult = this.chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + var queryResult = this.chromaApi.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } } @@ -114,14 +125,14 @@ public void testCollection() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3"); // Update existing embedding. 
- this.chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, + this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = this.chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); assertThat(result.ids().get(0)).isEqualTo("id2"); - queryResult = this.chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } } @@ -133,7 +144,7 @@ public void testCollection() { @Test public void testQueryWhere() { - var collection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f }, Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020), @@ -146,24 +157,25 @@ public void testQueryWhere() { Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023), "The World is Big and Salvation Lurks Around the Corner"); - this.chroma.upsertEmbeddings(collection.id(), add1); - this.chroma.upsertEmbeddings(collection.id(), add2); - this.chroma.upsertEmbeddings(collection.id(), add3); + this.chromaApi.upsertEmbeddings(collection.id(), add1); + this.chromaApi.upsertEmbeddings(collection.id(), add2); + this.chromaApi.upsertEmbeddings(collection.id(), add3); - assertThat(this.chroma.countEmbeddings(collection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = 
this.chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3"); - var chromaEmbeddings = this.chroma.toEmbeddingResponseList(queryResult); + var chromaEmbeddings = this.chromaApi.toEmbeddingResponseList(queryResult); assertThat(chromaEmbeddings).hasSize(3); assertThat(chromaEmbeddings).hasSize(3); - queryResult = this.chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -174,8 +186,8 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).hasSize(2); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3"); - queryResult = this.chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -188,6 +200,53 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1"); } + @Test + void shouldUseExistingCollectionWhenSchemaInitializationDisabled() { // initializeSchema + // is false by + // default. 
+ var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("test-collection")); + assertThat(collection).isNotNull(); + assertThat(collection.name()).isEqualTo("test-collection"); + + ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi) + .collectionName("test-collection") + .initializeImmediately(true) + .build(); + + Document document = new Document("test content"); + assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document))); + } + + @Test + void shouldCreateNewCollectionWhenSchemaInitializationEnabled() { + ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi) + .collectionName("new-collection") + .initializeSchema(true) + .initializeImmediately(true) + .build(); + + var collection = this.chromaApi.getCollection("new-collection"); + assertThat(collection).isNotNull(); + assertThat(collection.name()).isEqualTo("new-collection"); + + Document document = new Document("test content"); + assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document))); + } + + @Test + void shouldFailWhenCollectionDoesNotExist() { + assertThatThrownBy( + () -> new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi).collectionName("non-existent") + .initializeSchema(false) + .initializeImmediately(true) + .build()) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Failed to initialize ChromaVectorStore") + .hasCauseInstanceOf(RuntimeException.class) + .hasRootCauseMessage( + "Collection non-existent doesn't exist and won't be created as the initializeSchema is set to false."); + } + @SpringBootConfiguration public static class Config { @@ -196,6 +255,11 @@ public ChromaApi chromaApi() { return new ChromaApi(chromaContainer.getEndpoint()); } + @Bean + public EmbeddingModel embeddingModel() { + return new TransformersEmbeddingModel(); + } + } }