diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml
index 3e7b35c4a8b..88f6738c4af 100644
--- a/vector-stores/spring-ai-chroma-store/pom.xml
+++ b/vector-stores/spring-ai-chroma-store/pom.xml
@@ -87,6 +87,13 @@
micrometer-observation-test
test
+
+
+ org.springframework.ai
+ spring-ai-transformers
+ ${parent.version}
+ test
+
diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java
index 90899516ae9..ba1166927c6 100644
--- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java
+++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java
@@ -41,7 +41,6 @@
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
-import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
@@ -58,7 +57,7 @@
* @author Christian Tzolov
* @author Fu Cheng
* @author Sebastien Deleuze
- *
+ * @author Soby Chacko
*/
public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -66,10 +65,6 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";
- public static final double SIMILARITY_THRESHOLD_ALL = 0.0;
-
- public static final int DEFAULT_TOP_K = 4;
-
private final EmbeddingModel embeddingModel;
private final ChromaApi chromaApi;
@@ -86,6 +81,8 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
private final ObjectMapper objectMapper;
+ private boolean initialized = false;
+
public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, boolean initializeSchema) {
this(embeddingModel, chromaApi, DEFAULT_COLLECTION_NAME, initializeSchema);
}
@@ -111,6 +108,26 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
}
+ private ChromaVectorStore(Builder builder) {
+ super(builder.observationRegistry, builder.customObservationConvention);
+ this.embeddingModel = builder.embeddingModel;
+ this.chromaApi = builder.chromaApi;
+ this.collectionName = builder.collectionName;
+ this.initializeSchema = builder.initializeSchema;
+ this.filterExpressionConverter = builder.filterExpressionConverter;
+ this.batchingStrategy = builder.batchingStrategy;
+ this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
+
+ if (builder.initializeImmediately) {
+ try {
+ afterPropertiesSet();
+ }
+ catch (Exception e) {
+ throw new IllegalStateException("Failed to initialize ChromaVectorStore", e);
+ }
+ }
+ }
+
public void setFilterExpressionConverter(FilterExpressionConverter filterExpressionConverter) {
Assert.notNull(filterExpressionConverter, "FilterExpressionConverter should not be null.");
this.filterExpressionConverter = filterExpressionConverter;
@@ -207,26 +224,95 @@ public String getCollectionId() {
@Override
public void afterPropertiesSet() throws Exception {
- var collection = this.chromaApi.getCollection(this.collectionName);
- if (collection == null) {
- if (this.initializeSchema) {
- collection = this.chromaApi
- .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName));
- }
- else {
- throw new RuntimeException("Collection " + this.collectionName
- + " doesn't exist and won't be created as the initializeSchema is set to false.");
+ if (!this.initialized) {
+ var collection = this.chromaApi.getCollection(this.collectionName);
+ if (collection == null) {
+ if (this.initializeSchema) {
+ collection = this.chromaApi
+ .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName));
+ }
+ else {
+ throw new RuntimeException("Collection " + this.collectionName
+ + " doesn't exist and won't be created as the initializeSchema is set to false.");
+ }
}
+ this.collectionId = collection.id();
+ this.initialized = true;
}
- this.collectionId = collection.id();
}
@Override
- public Builder createObservationContextBuilder(String operationName) {
+ public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
.withDimensions(this.embeddingModel.dimensions())
.withCollectionName(this.collectionName + ":" + this.collectionId)
.withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null);
}
+ public static class Builder {
+
+ private final EmbeddingModel embeddingModel;
+
+ private final ChromaApi chromaApi;
+
+ private String collectionName = DEFAULT_COLLECTION_NAME;
+
+ private boolean initializeSchema = false;
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ private VectorStoreObservationConvention customObservationConvention = null;
+
+ private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
+
+ private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter();
+
+ private boolean initializeImmediately = false;
+
+ public Builder(EmbeddingModel embeddingModel, ChromaApi chromaApi) {
+ this.embeddingModel = embeddingModel;
+ this.chromaApi = chromaApi;
+ }
+
+ public Builder collectionName(String collectionName) {
+ this.collectionName = collectionName;
+ return this;
+ }
+
+ public Builder initializeSchema(boolean initializeSchema) {
+ this.initializeSchema = initializeSchema;
+ return this;
+ }
+
+ public Builder observationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public Builder customObservationConvention(VectorStoreObservationConvention convention) {
+ this.customObservationConvention = convention;
+ return this;
+ }
+
+ public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
+ this.batchingStrategy = batchingStrategy;
+ return this;
+ }
+
+ public Builder filterExpressionConverter(FilterExpressionConverter converter) {
+ this.filterExpressionConverter = converter;
+ return this;
+ }
+
+ public Builder initializeImmediately(boolean initialize) {
+ this.initializeImmediately = initialize;
+ return this;
+ }
+
+ public ChromaVectorStore build() {
+ return new ChromaVectorStore(this);
+ }
+
+ }
+
}
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java
index bfc84340294..0b01ae35486 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java
@@ -16,6 +16,7 @@
package org.springframework.ai.chroma;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -30,17 +31,24 @@
import org.springframework.ai.chroma.ChromaApi.Collection;
import org.springframework.ai.chroma.ChromaApi.GetEmbeddingsRequest;
import org.springframework.ai.chroma.ChromaApi.QueryRequest;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.transformers.TransformersEmbeddingModel;
+import org.springframework.ai.vectorstore.ChromaVectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
/**
* @author Christian Tzolov
* @author EddĂș MelĂ©ndez
* @author Thomas Vitale
+ * @author Soby Chacko
*/
@SpringBootTest
@Testcontainers
@@ -50,17 +58,20 @@ public class ChromaApiIT {
static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE);
@Autowired
- ChromaApi chroma;
+ ChromaApi chromaApi;
+
+ @Autowired
+ EmbeddingModel embeddingModel;
@BeforeEach
public void beforeEach() {
- this.chroma.listCollections().stream().forEach(c -> this.chroma.deleteCollection(c.name()));
+ this.chromaApi.listCollections().stream().forEach(c -> this.chromaApi.deleteCollection(c.name()));
}
@Test
public void testClientWithMetadata() {
Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5);
- var newCollection = this.chroma
+ var newCollection = this.chromaApi
.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata));
assertThat(newCollection).isNotNull();
assertThat(newCollection.name()).isEqualTo("TestCollection");
@@ -68,44 +79,44 @@ public void testClientWithMetadata() {
@Test
public void testClient() {
- var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
+ var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
assertThat(newCollection).isNotNull();
assertThat(newCollection.name()).isEqualTo("TestCollection");
- var getCollection = this.chroma.getCollection("TestCollection");
+ var getCollection = this.chromaApi.getCollection("TestCollection");
assertThat(getCollection).isNotNull();
assertThat(getCollection.name()).isEqualTo("TestCollection");
assertThat(getCollection.id()).isEqualTo(newCollection.id());
- List collections = this.chroma.listCollections();
+ List collections = this.chromaApi.listCollections();
assertThat(collections).hasSize(1);
assertThat(collections.get(0).id()).isEqualTo(newCollection.id());
- this.chroma.deleteCollection(newCollection.name());
- assertThat(this.chroma.listCollections()).hasSize(0);
+ this.chromaApi.deleteCollection(newCollection.name());
+ assertThat(this.chromaApi.listCollections()).hasSize(0);
}
@Test
public void testCollection() {
- var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
- assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(0);
+ var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
+ assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(0);
var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"),
List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }),
List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)),
List.of("Hello World", "Big World"));
- this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest);
+ this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest);
var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f },
Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World");
- this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2);
+ this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2);
- assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(3);
+ assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(3);
- var queryResult = this.chroma.queryCollection(newCollection.id(),
- new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where("""
+ var queryResult = this.chromaApi.queryCollection(newCollection.id(),
+ new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"key2" : { "$eq": true }
}
@@ -114,14 +125,14 @@ public void testCollection() {
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3");
// Update existing embedding.
- this.chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
+ this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World"));
- var result = this.chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2")));
+ var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2")));
assertThat(result.ids().get(0)).isEqualTo("id2");
- queryResult = this.chroma.queryCollection(newCollection.id(),
- new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where("""
+ queryResult = this.chromaApi.queryCollection(newCollection.id(),
+ new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"key2" : { "$eq": true }
}
@@ -133,7 +144,7 @@ public void testCollection() {
@Test
public void testQueryWhere() {
- var collection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
+ var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f },
Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020),
@@ -146,24 +157,25 @@ public void testQueryWhere() {
Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023),
"The World is Big and Salvation Lurks Around the Corner");
- this.chroma.upsertEmbeddings(collection.id(), add1);
- this.chroma.upsertEmbeddings(collection.id(), add2);
- this.chroma.upsertEmbeddings(collection.id(), add3);
+ this.chromaApi.upsertEmbeddings(collection.id(), add1);
+ this.chromaApi.upsertEmbeddings(collection.id(), add2);
+ this.chromaApi.upsertEmbeddings(collection.id(), add3);
- assertThat(this.chroma.countEmbeddings(collection.id())).isEqualTo(3);
+ assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3);
- var queryResult = this.chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3));
+ var queryResult = this.chromaApi.queryCollection(collection.id(),
+ new QueryRequest(new float[] { 1f, 1f, 1f }, 3));
assertThat(queryResult.ids().get(0)).hasSize(3);
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3");
- var chromaEmbeddings = this.chroma.toEmbeddingResponseList(queryResult);
+ var chromaEmbeddings = this.chromaApi.toEmbeddingResponseList(queryResult);
assertThat(chromaEmbeddings).hasSize(3);
assertThat(chromaEmbeddings).hasSize(3);
- queryResult = this.chroma.queryCollection(collection.id(),
- new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where("""
+ queryResult = this.chromaApi.queryCollection(collection.id(),
+ new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"$and" : [
{"country" : { "$eq": "BG"}},
@@ -174,8 +186,8 @@ public void testQueryWhere() {
assertThat(queryResult.ids().get(0)).hasSize(2);
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3");
- queryResult = this.chroma.queryCollection(collection.id(),
- new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where("""
+ queryResult = this.chromaApi.queryCollection(collection.id(),
+ new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"$and" : [
{"country" : { "$eq": "BG"}},
@@ -188,6 +200,53 @@ public void testQueryWhere() {
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1");
}
+ @Test
+ void shouldUseExistingCollectionWhenSchemaInitializationDisabled() { // initializeSchema
+ // is false by
+ // default.
+ var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("test-collection"));
+ assertThat(collection).isNotNull();
+ assertThat(collection.name()).isEqualTo("test-collection");
+
+ ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi)
+ .collectionName("test-collection")
+ .initializeImmediately(true)
+ .build();
+
+ Document document = new Document("test content");
+ assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document)));
+ }
+
+ @Test
+ void shouldCreateNewCollectionWhenSchemaInitializationEnabled() {
+ ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi)
+ .collectionName("new-collection")
+ .initializeSchema(true)
+ .initializeImmediately(true)
+ .build();
+
+ var collection = this.chromaApi.getCollection("new-collection");
+ assertThat(collection).isNotNull();
+ assertThat(collection.name()).isEqualTo("new-collection");
+
+ Document document = new Document("test content");
+ assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document)));
+ }
+
+ @Test
+ void shouldFailWhenCollectionDoesNotExist() {
+ assertThatThrownBy(
+ () -> new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi).collectionName("non-existent")
+ .initializeSchema(false)
+ .initializeImmediately(true)
+ .build())
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("Failed to initialize ChromaVectorStore")
+ .hasCauseInstanceOf(RuntimeException.class)
+ .hasRootCauseMessage(
+ "Collection non-existent doesn't exist and won't be created as the initializeSchema is set to false.");
+ }
+
@SpringBootConfiguration
public static class Config {
@@ -196,6 +255,11 @@ public ChromaApi chromaApi() {
return new ChromaApi(chromaContainer.getEndpoint());
}
+ @Bean
+ public EmbeddingModel embeddingModel() {
+ return new TransformersEmbeddingModel();
+ }
+
}
}