1616
1717package org .springframework .ai .vectorstore ;
1818
19- import com . azure . cosmos .* ;
20- import com . azure . cosmos . implementation . guava25 . collect . ImmutableList ;
21- import com . azure . cosmos . models .* ;
22- import com . azure . cosmos . util .CosmosPagedFlux ;
23- import com . fasterxml . jackson . databind . JsonNode ;
24- import com . fasterxml . jackson . databind . ObjectMapper ;
25- import com . fasterxml . jackson . databind . node . ObjectNode ;
26- import io . micrometer . observation . ObservationRegistry ;
19+ import java . util . ArrayList ;
20+ import java . util . Collections ;
21+ import java . util . HashMap ;
22+ import java . util .List ;
23+ import java . util . Optional ;
24+ import java . util . stream . Collectors ;
25+ import java . util . stream . IntStream ;
26+
2727import org .apache .commons .lang3 .tuple .ImmutablePair ;
28- import org .apache .commons .lang3 .tuple .Pair ;
2928import org .slf4j .Logger ;
3029import org .slf4j .LoggerFactory ;
3130import org .springframework .ai .document .Document ;
3837import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
3938import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
4039import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
40+
41+ import com .azure .cosmos .CosmosAsyncClient ;
42+ import com .azure .cosmos .CosmosAsyncContainer ;
43+ import com .azure .cosmos .CosmosAsyncDatabase ;
44+ import com .azure .cosmos .implementation .guava25 .collect .ImmutableList ;
45+ import com .azure .cosmos .models .CosmosBulkOperations ;
46+ import com .azure .cosmos .models .CosmosContainerProperties ;
47+ import com .azure .cosmos .models .CosmosItemOperation ;
48+ import com .azure .cosmos .models .CosmosQueryRequestOptions ;
49+ import com .azure .cosmos .models .CosmosVectorDataType ;
50+ import com .azure .cosmos .models .CosmosVectorDistanceFunction ;
51+ import com .azure .cosmos .models .CosmosVectorEmbedding ;
52+ import com .azure .cosmos .models .CosmosVectorEmbeddingPolicy ;
53+ import com .azure .cosmos .models .CosmosVectorIndexSpec ;
54+ import com .azure .cosmos .models .CosmosVectorIndexType ;
55+ import com .azure .cosmos .models .ExcludedPath ;
56+ import com .azure .cosmos .models .IncludedPath ;
57+ import com .azure .cosmos .models .IndexingMode ;
58+ import com .azure .cosmos .models .IndexingPolicy ;
59+ import com .azure .cosmos .models .PartitionKey ;
60+ import com .azure .cosmos .models .PartitionKeyDefinition ;
61+ import com .azure .cosmos .models .PartitionKind ;
62+ import com .azure .cosmos .models .SqlParameter ;
63+ import com .azure .cosmos .models .SqlQuerySpec ;
64+ import com .azure .cosmos .models .ThroughputProperties ;
65+ import com .azure .cosmos .util .CosmosPagedFlux ;
66+ import com .fasterxml .jackson .databind .JsonNode ;
67+ import com .fasterxml .jackson .databind .ObjectMapper ;
68+ import com .fasterxml .jackson .databind .node .ObjectNode ;
69+
70+ import io .micrometer .observation .ObservationRegistry ;
4171import reactor .core .publisher .Flux ;
42- import java .util .*;
43- import java .util .stream .Collectors ;
44- import java .util .stream .IntStream ;
4572
4673/**
4774 * @author Theo van Kraay
@@ -79,38 +106,38 @@ public CosmosDBVectorStore(ObservationRegistry observationRegistry,
79106 cosmosClient .createDatabaseIfNotExists (properties .getDatabaseName ()).block ();
80107
81108 initializeContainer (properties .getContainerName (), properties .getDatabaseName (),
82- properties .getVectorStoreThoughput (), properties .getVectorDimensions (),
109+ properties .getVectorStoreThroughput (), properties .getVectorDimensions (),
83110 properties .getPartitionKeyPath ());
84111
85112 this .embeddingModel = embeddingModel ;
86113 }
87114
88- private void initializeContainer (String containerName , String databaseName , int vectorStoreThoughput ,
115+ private void initializeContainer (String containerName , String databaseName , int vectorStoreThroughput ,
89116 long vectorDimensions , String partitionKeyPath ) {
90117
91118 // Set defaults if not provided
92- if (vectorStoreThoughput == 0 ) {
93- vectorStoreThoughput = 400 ;
119+ if (vectorStoreThroughput == 0 ) {
120+ vectorStoreThroughput = 400 ;
94121 }
95122 if (partitionKeyPath == null ) {
96123 partitionKeyPath = "/id" ;
97124 }
98125
99126 // handle hierarchical partition key
100- PartitionKeyDefinition subpartitionKeyDefinition = new PartitionKeyDefinition ();
101- List <String > pathsfromCommaSeparatedList = new ArrayList <String >();
102- String [] subpartitionKeyPaths = partitionKeyPath .split ("," );
103- Collections .addAll (pathsfromCommaSeparatedList , subpartitionKeyPaths );
104- if (subpartitionKeyPaths .length > 1 ) {
105- subpartitionKeyDefinition .setPaths (pathsfromCommaSeparatedList );
106- subpartitionKeyDefinition .setKind (PartitionKind .MULTI_HASH );
127+ PartitionKeyDefinition subPartitionKeyDefinition = new PartitionKeyDefinition ();
128+ List <String > pathsFromCommaSeparatedList = new ArrayList <String >();
129+ String [] subPartitionKeyPaths = partitionKeyPath .split ("," );
130+ Collections .addAll (pathsFromCommaSeparatedList , subPartitionKeyPaths );
131+ if (subPartitionKeyPaths .length > 1 ) {
132+ subPartitionKeyDefinition .setPaths (pathsFromCommaSeparatedList );
133+ subPartitionKeyDefinition .setKind (PartitionKind .MULTI_HASH );
107134 }
108135 else {
109- subpartitionKeyDefinition .setPaths (Collections .singletonList (partitionKeyPath ));
110- subpartitionKeyDefinition .setKind (PartitionKind .HASH );
136+ subPartitionKeyDefinition .setPaths (Collections .singletonList (partitionKeyPath ));
137+ subPartitionKeyDefinition .setKind (PartitionKind .HASH );
111138 }
112139 CosmosContainerProperties collectionDefinition = new CosmosContainerProperties (containerName ,
113- subpartitionKeyDefinition );
140+ subPartitionKeyDefinition );
114141 // Set vector embedding policy
115142 CosmosVectorEmbeddingPolicy embeddingPolicy = new CosmosVectorEmbeddingPolicy ();
116143 CosmosVectorEmbedding embedding = new CosmosVectorEmbedding ();
@@ -135,16 +162,16 @@ private void initializeContainer(String containerName, String databaseName, int
135162 indexingPolicy .setVectorIndexes (List .of (cosmosVectorIndexSpec ));
136163 collectionDefinition .setIndexingPolicy (indexingPolicy );
137164
138- ThroughputProperties throughputProperties = ThroughputProperties .createManualThroughput (vectorStoreThoughput );
139- CosmosAsyncDatabase cosmosAsyncDatabase = cosmosClient .getDatabase (databaseName );
165+ ThroughputProperties throughputProperties = ThroughputProperties .createManualThroughput (vectorStoreThroughput );
166+ CosmosAsyncDatabase cosmosAsyncDatabase = this . cosmosClient .getDatabase (databaseName );
140167 cosmosAsyncDatabase .createContainerIfNotExists (collectionDefinition , throughputProperties ).block ();
141168 this .container = cosmosAsyncDatabase .getContainer (containerName );
142169 }
143170
144171 @ Override
145172 public void close () {
146- if (cosmosClient != null ) {
147- cosmosClient .close ();
173+ if (this . cosmosClient != null ) {
174+ this . cosmosClient .close ();
148175 logger .info ("Cosmos DB client closed successfully." );
149176 }
150177 }
@@ -192,7 +219,7 @@ public void doAdd(List<Document> documents) {
192219 .map (ImmutablePair ::getValue )
193220 .collect (Collectors .toList ());
194221
195- container .executeBulkOperations (Flux .fromIterable (itemOperations )).doOnNext (response -> {
222+ this . container .executeBulkOperations (Flux .fromIterable (itemOperations )).doOnNext (response -> {
196223 if (response != null && response .getResponse () != null ) {
197224 int statusCode = response .getResponse ().getStatusCode ();
198225 if (statusCode == 409 ) {
@@ -236,7 +263,7 @@ public Optional<Boolean> doDelete(List<String> idList) {
236263
237264 // Execute bulk delete operations synchronously by using blockLast() on the
238265 // Flux
239- container .executeBulkOperations (Flux .fromIterable (itemOperations ))
266+ this . container .executeBulkOperations (Flux .fromIterable (itemOperations ))
240267 .doOnNext (response -> logger .info ("Document deleted with status: {}" ,
241268 response .getResponse ().getStatusCode ()))
242269 .doOnError (error -> logger .error ("Error deleting document: {}" , error .getMessage ()))
@@ -279,9 +306,11 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
279306 Filter .Expression filterExpression = request .getFilterExpression ();
280307 if (filterExpression != null ) {
281308 CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter (
282- properties .getMetadataFieldsList ()); // Use the expression directly as
283- // it handles the "metadata"
284- // fields internally
309+ this .properties .getMetadataFieldsList ()); // Use the expression
310+ // directly as
311+ // it handles the
312+ // "metadata"
313+ // fields internally
285314 String filterQuery = filterExpressionConverter .convertExpression (filterExpression );
286315 queryBuilder .append (" AND " ).append (filterQuery );
287316 }
@@ -297,7 +326,7 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
297326 SqlQuerySpec sqlQuerySpec = new SqlQuerySpec (query , parameters );
298327 CosmosQueryRequestOptions options = new CosmosQueryRequestOptions ();
299328
300- CosmosPagedFlux <JsonNode > pagedFlux = container .queryItems (sqlQuerySpec , options , JsonNode .class );
329+ CosmosPagedFlux <JsonNode > pagedFlux = this . container .queryItems (sqlQuerySpec , options , JsonNode .class );
301330
302331 logger .info ("Executing similarity search query: {}" , query );
303332 try {
@@ -322,9 +351,9 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
322351 @ Override
323352 public VectorStoreObservationContext .Builder createObservationContextBuilder (String operationName ) {
324353 return VectorStoreObservationContext .builder (VectorStoreProvider .COSMOSDB .value (), operationName )
325- .withCollectionName (container .getId ())
354+ .withCollectionName (this . container .getId ())
326355 .withDimensions (this .embeddingModel .dimensions ())
327- .withNamespace (container .getDatabase ().getId ())
356+ .withNamespace (this . container .getDatabase ().getId ())
328357 .withSimilarityMetric ("cosine" );
329358 }
330359
0 commit comments