diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml new file mode 100644 index 00000000000..737dfbb162c --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml @@ -0,0 +1,102 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../../pom.xml + + spring-ai-autoconfigure-model-cohere + jar + Spring AI Cohere Auto Configuration + Spring AI Cohere Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-cohere + ${project.parent.version} + true + + + + org.springframework.ai + spring-ai-autoconfigure-model-tool + ${project.parent.version} + + + + org.springframework.ai + spring-ai-autoconfigure-retry + ${project.parent.version} + + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-observation + ${project.parent.version} + + + + org.springframework.boot + spring-boot-starter + true + + + + org.springframework.boot + spring-boot-starter-webclient + true + + + org.springframework.boot + spring-boot-starter-restclient + true + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.mockito + mockito-core + test + + + + diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java new file mode 100644 index 00000000000..4fee095d6c9 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.chat.CohereChatModel; +import org.springframework.ai.model.SpringAIModelProperties; +import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * Chat {@link AutoConfiguration Auto-configuration} for Cohere. + * + * @author Ricken Bazolo + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class, + SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class }) +@EnableConfigurationProperties({ CohereCommonProperties.class, CohereChatProperties.class }) +@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.COHERE, + matchIfMissing = true) +@ConditionalOnClass(CohereApi.class) +@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, + ToolCallingAutoConfiguration.class }) +public class CohereChatAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public CohereChatModel chereChatModel(CohereCommonProperties commonProperties, CohereChatProperties chatProperties, + ObjectProvider restClientBuilderProvider, + ObjectProvider webClientBuilderProvider, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, + ObjectProvider observationRegistry, + ObjectProvider observationConvention, + ObjectProvider cohereToolExecutionEligibilityPredicate) { + var cohereApi = cohereApi(chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getBaseUrl(), + commonProperties.getBaseUrl(), restClientBuilderProvider.getIfAvailable(RestClient::builder), + webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler); + + var chatModel = CohereChatModel.builder() + .cohereApi(cohereApi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate( + cohereToolExecutionEligibilityPredicate.getIfUnique(DefaultToolExecutionEligibilityPredicate::new)) + .retryTemplate(new RetryTemplate()) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); + + observationConvention.ifAvailable(chatModel::setObservationConvention); + + return chatModel; + } + + private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { + + var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; + var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; + + Assert.hasText(resolvedApiKey, "Cohere API key must be set"); + Assert.hasText(resoledBaseUrl, "Cohere base URL must be set"); + + return CohereApi.builder() + .baseUrl(resoledBaseUrl) + .apiKey(resolvedApiKey) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java new file mode 100644 index 00000000000..49ef712477f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.chat.CohereChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Configuration properties for Cohere chat. + * + * @author Ricken Bazolo + */ +@ConfigurationProperties(CohereChatProperties.CONFIG_PREFIX) +public class CohereChatProperties extends CohereParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.cohere.chat"; + + public static final String DEFAULT_CHAT_MODEL = CohereApi.ChatModel.COMMAND_A_R7B.getValue(); + + private static final Double DEFAULT_TEMPERATURE = 0.3; + + private static final Double DEFAULT_TOP_P = 1.0; + + @NestedConfigurationProperty + private CohereChatOptions options = CohereChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(DEFAULT_TEMPERATURE) + .topP(DEFAULT_TOP_P) + .presencePenalty(0.0) + .frequencyPenalty(0.0) + .logprobs(false) + .build(); + + public CohereChatProperties() { + super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL); + } + + public CohereChatOptions getOptions() { + return this.options; + } + + public void setOptions(CohereChatOptions options) { + this.options = options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java new file mode 100644 index 00000000000..db672b8328b --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Common properties for Cohere. + * + * @author Ricken Bazolo + */ +@ConfigurationProperties(CohereCommonProperties.CONFIG_PREFIX) +public class CohereCommonProperties extends CohereParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.cohere"; + + public static final String DEFAULT_BASE_URL = "https://api.cohere.com"; + + public CohereCommonProperties() { + super.setBaseUrl(DEFAULT_BASE_URL); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java new file mode 100644 index 00000000000..127610a2288 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.embedding.CohereEmbeddingModel; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.model.SpringAIModelProperties; +import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * Embedding {@link AutoConfiguration Auto-configuration} for Cohere + * + * @author Ricken Bazolo + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@EnableConfigurationProperties({ CohereCommonProperties.class, CohereEmbeddingProperties.class }) +@ConditionalOnClass(CohereApi.class) +@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.COHERE, + matchIfMissing = true) +public class CohereEmbeddingAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public CohereEmbeddingModel mistralAiEmbeddingModel(CohereCommonProperties commonProperties, + CohereEmbeddingProperties embeddingProperties, ObjectProvider restClientBuilderProvider, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, + ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var cohereApi = cohereApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(), + embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), + restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); + + var embeddingModel = CohereEmbeddingModel.builder() + .cohereApi(cohereApi) + .metadataMode(embeddingProperties.getMetadataMode()) + .options(embeddingProperties.getOptions()) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); + + observationConvention.ifAvailable(embeddingModel::setObservationConvention); + + return embeddingModel; + } + + private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; + var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; + + Assert.hasText(resolvedApiKey, "Cohere API key must be set"); + Assert.hasText(resoledBaseUrl, "Cohere base URL must be set"); + + return CohereApi.builder() + .baseUrl(resoledBaseUrl) + .apiKey(resolvedApiKey) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java new file mode 100644 index 00000000000..75acd6b3f8c --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import java.util.List; + +import org.springframework.ai.cohere.api.CohereApi.EmbeddingModel; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingType; +import org.springframework.ai.cohere.embedding.CohereEmbeddingOptions; +import org.springframework.ai.document.MetadataMode; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Configuration properties for Cohere embedding model. + * + * @author Ricken Bazolo + */ +@ConfigurationProperties(CohereEmbeddingProperties.CONFIG_PREFIX) +public class CohereEmbeddingProperties extends CohereParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.cohere.embedding"; + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBED_V4.getValue(); + + public static final String DEFAULT_ENCODING_FORMAT = EmbeddingType.FLOAT.name(); + + public MetadataMode metadataMode = MetadataMode.EMBED; + + @NestedConfigurationProperty + private final CohereEmbeddingOptions options = CohereEmbeddingOptions.builder() + .model(DEFAULT_EMBEDDING_MODEL) + .embeddingTypes(List.of(EmbeddingType.valueOf(DEFAULT_ENCODING_FORMAT))) + .build(); + + public CohereEmbeddingProperties() { + super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL); + } + + public CohereEmbeddingOptions getOptions() { + return this.options; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java new file mode 100644 index 00000000000..72901f91c1e --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel; +import org.springframework.ai.model.SpringAIModelProperties; +import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * Multimodal Embedding {@link AutoConfiguration Auto-configuration} for Cohere + * + * @author Ricken Bazolo + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@EnableConfigurationProperties({ CohereCommonProperties.class, CohereMultimodalEmbeddingProperties.class }) +@ConditionalOnClass({ CohereApi.class, CohereMultimodalEmbeddingModel.class }) +@ConditionalOnProperty(name = SpringAIModelProperties.MULTI_MODAL_EMBEDDING_MODEL, havingValue = SpringAIModels.COHERE, + matchIfMissing = true) +public class CohereMultimodalEmbeddingAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public CohereMultimodalEmbeddingModel cohereMultimodalEmbeddingModel(CohereCommonProperties commonProperties, + CohereMultimodalEmbeddingProperties embeddingProperties, + ObjectProvider restClientBuilderProvider, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry) { + + var cohereApi = cohereApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(), + embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), + restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); + + return CohereMultimodalEmbeddingModel.builder() + .cohereApi(cohereApi) + .options(embeddingProperties.getOptions()) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); + } + + private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; + var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; + + Assert.hasText(resolvedApiKey, "Cohere API key must be set"); + Assert.hasText(resoledBaseUrl, "Cohere base URL must be set"); + + return CohereApi.builder() + .baseUrl(resoledBaseUrl) + .apiKey(resolvedApiKey) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java new file mode 100644 index 00000000000..8caa70cdcc8 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import java.util.List; + +import org.springframework.ai.cohere.api.CohereApi.EmbeddingModel; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingType; +import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Configuration properties for Cohere multimodal embedding model. + * + * @author Ricken Bazolo + */ +@ConfigurationProperties(CohereMultimodalEmbeddingProperties.CONFIG_PREFIX) +public class CohereMultimodalEmbeddingProperties extends CohereParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.cohere.embedding.multimodal"; + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBED_V4.getValue(); + + public static final String DEFAULT_ENCODING_FORMAT = EmbeddingType.FLOAT.name(); + + @NestedConfigurationProperty + private final CohereMultimodalEmbeddingOptions options = CohereMultimodalEmbeddingOptions.builder() + .model(DEFAULT_EMBEDDING_MODEL) + .embeddingTypes(List.of(EmbeddingType.valueOf(DEFAULT_ENCODING_FORMAT))) + .build(); + + public CohereMultimodalEmbeddingProperties() { + super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL); + } + + public CohereMultimodalEmbeddingOptions getOptions() { + return this.options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java new file mode 100644 index 00000000000..0561690d7bb --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +/** + * Parent properties for Cohere. + * + * @author Ricken Bazolo + */ +public class CohereParentProperties { + + private String apiKey; + + private String baseUrl; + + public String getApiKey() { + return this.apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return this.baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json new file mode 100644 index 00000000000..5952657975c --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json @@ -0,0 +1,11 @@ +{ + "groups": [ + { + "name": "spring.ai.cohere.chat.options.tool-choice", + "type": "org.springframework.ai.cohere.api.CohereApi$ChatCompletionRequest$ToolChoice", + "sourceType": "org.springframework.ai.cohere.chat.CohereChatOptions" + } + ], + "properties": [], + "hints": [] +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..1643591155e --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,18 @@ +# +# Copyright 2025-2025 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +org.springframework.ai.cohere.autoconfigure.CohereChatAutoConfiguration +org.springframework.ai.cohere.autoconfigure.CohereEmbeddingAutoConfiguration +org.springframework.ai.cohere.autoconfigure.CohereMultimodalEmbeddingAutoConfiguration diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java new file mode 100644 index 00000000000..d8ec9261f35 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.cohere.chat.CohereChatModel; +import org.springframework.ai.cohere.embedding.CohereEmbeddingModel; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".*") +public class CohereAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(CohereAutoConfigurationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY")) + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class)); + + @Test + void generate() { + this.contextRunner.withConfiguration(AutoConfigurations.of(CohereChatAutoConfiguration.class)).run(context -> { + CohereChatModel chatModel = context.getBean(CohereChatModel.class); + String response = chatModel.call("Hello"); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void embedding() { + this.contextRunner.withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class)) + .run(context -> { + CohereEmbeddingModel embeddingModel = context.getBean(CohereEmbeddingModel.class); + + EmbeddingResponse embeddingResponse = embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + + assertThat(embeddingModel.dimensions()).isEqualTo(1536); + }); + } + + @Test + void generateStreaming() { + this.contextRunner.withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class)) + .run(context -> { + CohereChatModel chatModel = context.getBean(CohereChatModel.class); + Flux responseFlux = chatModel + .stream(new org.springframework.ai.chat.prompt.Prompt( + new org.springframework.ai.chat.messages.UserMessage("Hello"))); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText()) + .collect(java.util.stream.Collectors.joining()); + + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + public void multimodalEmbedding() { + this.contextRunner + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereMultimodalEmbeddingAutoConfiguration.class)) + .run(context -> { + var multimodalEmbeddingProperties = context.getBean(CohereMultimodalEmbeddingProperties.class); + + assertThat(multimodalEmbeddingProperties).isNotNull(); + + var multiModelEmbeddingModel = context + .getBean(org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel.class); + + assertThat(multiModelEmbeddingModel).isNotNull(); + + var document = new Document("Hello World"); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document), + EmbeddingOptions.builder().build()); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1536); + + }); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java new file mode 100644 index 00000000000..79ad7a6e21f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.cohere.chat.CohereChatModel; +import org.springframework.ai.cohere.embedding.CohereEmbeddingModel; +import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for Cohere auto-configurations conditional enabling of models. + * + * @author Ricken Bazolo + */ +public class CohereModelConfigurationTests { + + private final ApplicationContextRunner chatContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY")) + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class)); + + private final ApplicationContextRunner embeddingContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY")) + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class)); + + private final ApplicationContextRunner embeddingMultimodalContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY")) + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereMultimodalEmbeddingAutoConfiguration.class)); + + @Test + void chatModelActivation() { + this.chatContextRunner.run(context -> { + assertThat(context.getBeansOfType(CohereChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereChatModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty(); + }); + + this.chatContextRunner.withPropertyValues("spring.ai.model.chat=none", "spring.ai.model.embedding=none") + .run(context -> { + assertThat(context.getBeansOfType(CohereChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(CohereChatModel.class)).isEmpty(); + }); + + this.chatContextRunner.withPropertyValues("spring.ai.model.chat=cohere", "spring.ai.model.embedding=none") + .run(context -> { + assertThat(context.getBeansOfType(CohereChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereChatModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty(); + }); + } + + @Test + void embeddingModelActivation() { + this.embeddingContextRunner + .run(context -> assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isNotEmpty()); + + this.embeddingContextRunner.withPropertyValues("spring.ai.model.embedding=none").run(context -> { + assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty(); + }); + + this.embeddingContextRunner.withPropertyValues("spring.ai.model.embedding=cohere").run(context -> { + assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isNotEmpty(); + }); + } + + @Test + void multimodalEmbeddingActivation() { + this.embeddingMultimodalContextRunner + .run(context -> assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isNotEmpty()); + + this.embeddingMultimodalContextRunner.withPropertyValues("spring.ai.model.embedding.multimodal=none") + .run(context -> { + assertThat(context.getBeansOfType(CohereMultimodalEmbeddingProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isEmpty(); + }); + + this.embeddingMultimodalContextRunner.withPropertyValues("spring.ai.model.embedding.multimodal=cohere") + .run(context -> { + assertThat(context.getBeansOfType(CohereMultimodalEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isNotEmpty(); + }); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java new file mode 100644 index 00000000000..ff7f6fd88b4 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link CohereCommonProperties}. + */ +public class CoherePropertiesTests { + + @Test + public void chatOptionsTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123", + "spring.ai.cohere.chat.options.tools[0].function.name=myFunction1", + "spring.ai.cohere.chat.options.tools[0].function.description=function description", + "spring.ai.cohere.chat.options.tools[0].function.jsonSchema=" + """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["c", "f"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, CohereChatAutoConfiguration.class)) + .run(context -> { + + var chatProperties = context.getBean(CohereChatProperties.class); + + var tool = chatProperties.getOptions().getTools().get(0); + assertThat(tool.getType()).isEqualTo(CohereApi.FunctionTool.Type.FUNCTION); + var function = tool.getFunction(); + assertThat(function.getName()).isEqualTo("myFunction1"); + assertThat(function.getDescription()).isEqualTo("function description"); + assertThat(function.getParameters()).isNotEmpty(); + }); + } + + @Test + public void embeddingProperties() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123", + "spring.ai.cohere.embedding.options.model=MODEL_XYZ") + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(CohereEmbeddingProperties.class); + var connectionProperties = context.getBean(CohereCommonProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isNull(); + assertThat(embeddingProperties.getBaseUrl()).isEqualTo(CohereCommonProperties.DEFAULT_BASE_URL); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void embeddingOverrideConnectionProperties() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123", + "spring.ai.cohere.embedding.base-url=TEST_BASE_URL2", "spring.ai.cohere.embedding.api-key=456", + "spring.ai.cohere.embedding.options.model=MODEL_XYZ") + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(CohereEmbeddingProperties.class); + var connectionProperties = context.getBean(CohereCommonProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); + assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void embeddingOptionsTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.cohere.api-key=API_KEY", "spring.ai.cohere.base-url=TEST_BASE_URL", + "spring.ai.cohere.embedding.options.model=MODEL_XYZ", + "spring.ai.cohere.embedding.options.embedding-types[0]=FLOAT", + "spring.ai.cohere.embedding.options.input-type=search_document", + "spring.ai.cohere.embedding.options.truncate=END") + .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class)) + .run(context -> { + var connectionProperties = context.getBean(CohereCommonProperties.class); + var embeddingProperties = context.getBean(CohereEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(embeddingProperties.getOptions().getEmbeddingTypes().get(0).name()).isEqualTo("FLOAT"); + assertThat(embeddingProperties.getOptions().getTruncate().name()).isEqualTo("END"); + assertThat(embeddingProperties.getOptions().getInputType().name()).isEqualTo("SEARCH_DOCUMENT"); + }); + } + +} diff --git a/models/spring-ai-cohere/README.md b/models/spring-ai-cohere/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/models/spring-ai-cohere/pom.xml b/models/spring-ai-cohere/pom.xml new file mode 100644 index 00000000000..4b4d97dc4c8 --- /dev/null +++ b/models/spring-ai-cohere/pom.xml @@ -0,0 +1,90 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-cohere + jar + Spring AI Model - Cohere + Cohere models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + + + org.springframework.ai + spring-ai-model + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework + spring-webflux + + + + org.slf4j + slf4j-api + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + io.micrometer + micrometer-observation-test + test + + + + + diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java new file mode 100644 index 00000000000..b25b5a4843d --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.aot; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * The CohereRuntimeHints class is responsible for registering runtime hints for Cohere AI + * API classes. + * + * @author Ricken Bazolo + */ +public class CohereRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(final RuntimeHints hints, final ClassLoader classLoader) { + var mcs = MemberCategory.values(); + + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere")) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java new file mode 100644 index 00000000000..f79808832da --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java @@ -0,0 +1,1332 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.ai.model.ChatModelDescription; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * Java Client library for Cohere Platform. Provides implementation for the + * Chat and + * Chat Stream + * Embedding API. + *

+ * Implements Synchronous and Streaming chat completion and supports latest + * Function Calling features. + *

+ * + * @author Ricken Bazolo + */ +public class CohereApi { + + public static final String PROVIDER_NAME = AiProvider.COHERE.value(); + + private static final String DEFAULT_BASE_URL = "https://api.cohere.com"; + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final WebClient webClient; + + private final CohereStreamFunctionCallingHelper chunkMerger = new CohereStreamFunctionCallingHelper(); + + /** + * Create a new client api with DEFAULT_BASE_URL + * @param cohereApiKey Cohere api Key. + */ + public CohereApi(String cohereApiKey) { + this(DEFAULT_BASE_URL, cohereApiKey); + } + + /** + * Create a new client api. + * @param baseUrl api base URL. + * @param cohereApiKey Cohere api Key. + */ + public CohereApi(String baseUrl, String cohereApiKey) { + this(baseUrl, cohereApiKey, RestClient.builder(), WebClient.builder(), + RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new client api. + * @param baseUrl api base URL. + * @param cohereApiKey Cohere api Key. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public CohereApi(String baseUrl, String cohereApiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + + Consumer jsonContentHeaders = headers -> { + headers.setBearerAuth(cohereApiKey); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post().uri("/v2/chat/").body(chatRequest).retrieve().toEntity(ChatCompletion.class); + } + + /** + * Creates an embedding vector representing the input text, token array, or images. + * @param embeddingRequest The embedding request. + * @return Returns {@link EmbeddingResponse} with embeddings data. + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3")} 
+ */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + boolean hasTexts = !CollectionUtils.isEmpty(embeddingRequest.texts); + boolean hasImages = !CollectionUtils.isEmpty(embeddingRequest.images); + + Assert.isTrue(hasTexts || hasImages, "Either texts or images must be provided"); + Assert.isTrue(!(hasTexts && hasImages), "Cannot provide both texts and images in the same request"); + + if (hasTexts) { + Assert.isTrue(embeddingRequest.texts.size() <= 96, "The texts list must be 96 items or less"); + } + + if (hasImages) { + Assert.isTrue(embeddingRequest.images.size() <= 1, "Only one image per request is supported"); + } + + return this.restClient.post() + .uri("/v2/embed") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("v2/chat") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .groupBy(chunk -> chunk.id() != null ? chunk.id() : "no-id") + .flatMap(group -> group.reduce(new ChatCompletionChunk(null, null, null, null), this.chunkMerger::merge) + .filter(chunk -> EventType.MESSAGE_END.value.equals(chunk.type()) + || (chunk.delta() != null && chunk.delta().finishReason() != null))) + .map(this.chunkMerger::sanitizeToolCalls) + .filter(this.chunkMerger::hasValidToolCallsOnly) + .filter(Objects::nonNull); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating CohereApi instances. + */ + public static class Builder { + + private String baseUrl = DEFAULT_BASE_URL; + + private String apiKey; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private WebClient.Builder webClientBuilder = WebClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + this.webClientBuilder = webClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public CohereApi build() { + Assert.hasText(this.apiKey, "Cohere API key must be set"); + Assert.hasText(this.baseUrl, "Cohere base URL must be set"); + Assert.notNull(this.restClientBuilder, "RestClient.Builder must not be null"); + Assert.notNull(this.webClientBuilder, "WebClient.Builder must not be null"); + Assert.notNull(this.responseErrorHandler, "ResponseErrorHandler must not be null"); + + return new CohereApi(this.baseUrl, this.apiKey, this.restClientBuilder, this.webClientBuilder, + this.responseErrorHandler); + } + + } + + /** + * List of well-known Cohere chat models. + * + * @see Cohere Models Overview + */ + public enum ChatModel implements ChatModelDescription { + + COMMAND_A("command-a-03-2025"), + + COMMAND_A_REASONING("command-a-reasoning-08-2025"), + + COMMAND_A_TRANSLATE("command-a-translate-08-2025"), + + COMMAND_A_VISION("command-a-vision-07-2025"), + + COMMAND_A_R7B("command-r7b-12-2024"), + + COMMAND_R_PLUS("command-r-plus-08-2024"), + + COMMAND_R("command-r-08-2024"); + + private final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + + /** + * Usage statistics. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage(@JsonProperty("billedUnits") BilledUnits billedUnits, @JsonProperty("tokens") Tokens tokens, + @JsonProperty("cached_tokens") Integer cachedTokens) { + /** + * Bille units + * + * @param inputTokens The number of billed input tokens. + * @param outputTokens The number of billed output tokens. + * @param searchUnits The number of billed search units. + * @param classifications The number of billed classifications units. + */ + public record BilledUnits(@JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens, @JsonProperty("search_units") Double searchUnits, + @JsonProperty("classifications") Double classifications) { + } + + /** + * The Tokens + * + * @param inputTokens The number of tokens used as input to the model. + * @param outputTokens The number of tokens produced by the model. + */ + public record Tokens(@JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens) { + } + } + + /** + * Creates a model request for chat conversation. + * + * @param model The name of a compatible Cohere model or the ID of a fine-tuned model. + * @param messages The prompt(s) to generate completions for, encoded as a list of + * dict with role and rawContent. The first prompt role should be user or system. + * @param tools A list of tools the model may call. Currently, only functions are + * supported as a tool. Use this to provide a list of functions the model may generate + * JSON inputs for. + * @param documents A list of relevant documents that the model can cite to generate a + * more accurate reply. Each document is either a string or document object with + * rawContent and metadata. + * @param citationOptions Options for controlling citation generation. + * @param responseFormat An object specifying the format or schema that the model must + * output. Setting to { "type": "json_object" } enables JSON mode, which guarantees + * the message the model generates is valid JSON. Setting to { "type": "json_object" , + * "json_schema": schema} allows you to ensure the model provides an answer in a very + * specific JSON format by supplying a clear JSON schema. + * @param safetyMode Safety modes are not yet configurable in combination with tools, + * tool_results and documents parameters. + * @param maxTokens The maximum number of tokens to generate in the completion. The + * token count of your prompt plus max_tokens cannot exceed the model's context + * length. + * @param stopSequences A list of tokens that the model should stop generating after. + * If set, + * @param temperature What sampling temperature to use, between 0.0 and 1.0. Higher + * values like 0.8 will make the output more random, while lower values like 0.2 will + * make it more focused and deterministic. We generally recommend altering this or p + * but not both. + * @param seed If specified, the backend will make a best effort to sample tokens + * deterministically, such that repeated requests with the same seed and parameters + * should return the same result. However, determinism cannot be totally guaranteed. + * @param frequencyPenalty Number between 0.0 and 1.0. Used to reduce repetitiveness + * of generated tokens. The higher the value, the stronger a penalty is applied to + * previously present tokens, proportional to how many times they have already + * appeared in the prompt or prior generation. + * @param presencePenalty min value of 0.0, max value of 1.0. Used to reduce + * repetitiveness of generated tokens. Similar to frequency_penalty, except that this + * penalty is applied equally to all tokens that have already appeared, regardless of + * their exact frequencies. + * @param stream When true, the response will be a SSE stream of events. The final + * event will contain the complete response, and will have an event_type of + * "stream-end". + * @param k Ensures that only the top k most likely tokens are considered for + * generation at each step. When k is set to 0, k-sampling is disabled. Defaults to 0, + * min value of 0, max value of 500. + * @param p Ensures that only the most likely tokens, with total probability mass of + * p, are considered for generation at each step. If both k and p are enabled, p acts + * after k. Defaults to 0.75. min value of 0.01, max value of 0.99. + * @param logprobs Defaults to false. When set to true, the log probabilities of the + * generated tokens will be included in the response. + * @param toolChoice Used to control whether or not the model will be forced to use a + * tool when answering. When REQUIRED is specified, the model will be forced to use at + * least one of the user-defined tools, and the tools parameter must be passed in the + * request. When NONE is specified, the model will be forced not to use one of the + * specified tools, and give a direct response. If tool_choice isn’t specified, then + * the model is free to choose whether to use the specified tools or not. + * @param strictTools When set to true, tool calls in the Assistant message will be + * forced to follow the tool definition strictly. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionRequest(@JsonProperty("model") String model, + @JsonProperty("messages") List messages, + @JsonProperty("tools") List tools, @JsonProperty("documents") List documents, + @JsonProperty("citation_options") CitationOptions citationOptions, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("safety_mode") SafetyMode safetyMode, @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("stop_sequences") List stopSequences, @JsonProperty("temperature") Double temperature, + @JsonProperty("seed") Integer seed, @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("stream") Boolean stream, @JsonProperty("k") Integer k, @JsonProperty("p") Double p, + @JsonProperty("logprobs") Boolean logprobs, @JsonProperty("tool_choice") ToolChoice toolChoice, + @JsonProperty("strict_tools") Boolean strictTools, + @JsonProperty("presence_penalty") Double presencePenalty) { + + /** + * Shortcut constructor for a chat completion request with the given messages and + * model. + * @param messages The prompt(s) to generate completions for, encoded as a list of + * dict with role and rawContent. The first prompt role should be user or system. + * @param model ID or name of the model to use. + */ + public ChatCompletionRequest(List messages, String model) { + this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null, + null, 0.3, null, null, false, 0, 0.75, false, null, false, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model and temperature. + * @param messages The prompt(s) to generate completions for, encoded as a list of + * dict with role and rawContent. The first prompt role should be user or system. + * @param model ID or model of the model to use. + * @param temperature What sampling temperature to use, between 0.0 and 1.0. + * @param stream Whether to stream back partial progress. If set, tokens will be + * sent + */ + public ChatCompletionRequest(List messages, String model, Double temperature, + boolean stream) { + this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null, + null, temperature, null, null, stream, 0, 0.75, false, null, false, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model and temperature. + * @param messages The prompt(s) to generate completions for, encoded as a list of + * dict with role and rawContent. The first prompt role should be user or system. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0.0 and 1.0. + * + */ + public ChatCompletionRequest(List messages, String model, Double temperature) { + this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null, + null, temperature, null, null, false, 0, 0.75, false, null, false, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model, tools and tool choice. Streaming is set to false, temperature to 0.8 and + * all other parameters are null. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param tools A list of tools the model may call. Currently, only functions are + * supported as a tool. + * @param toolChoice Controls which (if any) function is called by the model. + */ + public ChatCompletionRequest(List messages, String model, List tools, + ToolChoice toolChoice) { + this(model, messages, tools, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, + null, null, 0.75, null, null, false, 0, 0.75, false, toolChoice, false, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages and + * stream. + */ + public ChatCompletionRequest(List messages, Boolean stream) { + this(null, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null, + null, 0.75, null, null, stream, 0, 0.75, false, null, false, null); + } + + /** + * An object specifying the format that the model must output. + * + * @param type Must be one of 'text' or 'json_object'. + * @param jsonSchema A specific JSON schema to match, if 'type' is 'json_object'. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResponseFormat(@JsonProperty("type") String type, + @JsonProperty("json_schema") Map jsonSchema) { + } + + /** + * Specifies a tool the model should use + */ + public enum ToolChoice { + + REQUIRED, NONE + + } + + } + + /** + * Message comprising the conversation. A message from the assistant role can contain + * text and tool call information. + * + * @param role The role of the messages author. Could be one of the {@link Role} types + * "assistant". + * @param toolCalls The tool calls generated by the model, such as function calls. + * Applicable only for {@link Role#ASSISTANT} role and null otherwise. + * @param toolPlan A chain-of-thought style reflection and plan that the model + * generates when working with Tools. + * @param rawContent The contents of the message. Can be either a {@link MediaContent} + * or a {@link MessageContent}. + * @param citations Tool call that this message is responding to. Only applicable for + * the {@link ChatCompletionFinishReason#TOOL_CALL} role and null otherwise. + */ + public record ChatCompletionMessage(@JsonProperty("content") Object rawContent, @JsonProperty("role") Role role, + @JsonProperty("tool_plan") String toolPlan, + @JsonFormat( + with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @JsonProperty("tool_calls") List toolCalls, + @JsonFormat( + with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @JsonProperty("citations") List citations, + @JsonProperty("tool_call_id") String toolCallId) { + + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null, null); + } + + public ChatCompletionMessage(Object content, Role role, List toolCalls) { + this(content, role, null, toolCalls, null, null); + } + + public ChatCompletionMessage(Object content, Role role, List toolCalls, String toolPlan) { + this(content, role, toolPlan, toolCalls, null, null); + } + + public ChatCompletionMessage(Object content, Role role, String toolCallId) { + this(content, role, null, null, null, toolCallId); + } + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * An array of rawContent parts with a defined type. Each MediaContent can be of + * either "text" or "image_url" type. Only one option allowed. + * + * @param type Content type, each can be of type text or image_url. + * @param text The text rawContent of the message. + * @param imageUrl The image rawContent of the message. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record MediaContent(@JsonProperty("type") String type, @JsonProperty("text") String text, + @JsonProperty("image_url") ImageUrl imageUrl) { + + /** + * Shortcut constructor for a text rawContent. + * @param text The text rawContent of the message. + */ + public MediaContent(String text) { + this("text", text, null); + } + + /** + * Shortcut constructor for an image rawContent. + * @param imageUrl The image rawContent of the message. + */ + public MediaContent(ImageUrl imageUrl) { + this("image_url", null, imageUrl); + } + + /** + * The level of detail for processing the image. + */ + public enum DetailLevel { + + @JsonProperty("low") + LOW, + + @JsonProperty("high") + HIGH, + + @JsonProperty("auto") + AUTO + + } + + /** + * Shortcut constructor for an image rawContent. + * + * @param url Either a URL of the image or the base64 encoded image data. The + * base64 encoded image data must have a special prefix in the following + * format: "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail The level of detail for processing the image. Can be "low", + * "high", or "auto". Defaults to "auto" if not specified. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") DetailLevel detail) { + + public ImageUrl(String url) { + this(url, DetailLevel.AUTO); + } + + } + } + + /** + * Message rawContent that can be either a text or a value. + * + * @param type The type of the message rawContent, such as "text" or "thinking". + * @param text The text rawContent of the message. + * @param value The value of the thinking, which can be any object. + */ + public record MessageContent(@JsonProperty("type") String type, @JsonProperty("text") String text, + @JsonProperty("value") Object value) { + } + + /** + * The role of the author of this message. + */ + public enum Role { + + /** + * User message. + */ + @JsonProperty("user") + USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") + ASSISTANT, + /** + * System message. + */ + @JsonProperty("system") + SYSTEM, + /** + * Tool message. + */ + @JsonProperty("tool") + TOOL + + } + + /** + * The relevant tool call. + * + * @param id The ID of the tool call. This ID must be referenced when you submit + * the tool outputs in using the Submit tool outputs to run endpoint. + * @param type The type of tool call the output is required for. For now, this is + * always function. + * @param function The function definition. + * @param index The index of the tool call in the list of tool calls. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function, @JsonProperty("index") Integer index) { + } + + /** + * The function definition. + * + * @param name The name of the function. + * @param arguments The arguments that the model expects you to pass to the + * function. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionFunction(@JsonProperty("name") String name, + @JsonProperty("arguments") String arguments) { + } + + public record ChatCompletionCitation( + /** + * Start index of the cited snippet in the original source text. + */ + @JsonProperty("start") Integer start, + /** + * End index of the cited snippet in the original source text. + */ + @JsonProperty("end") Integer end, + /** + * Text snippet that is being cited. + */ + @JsonProperty("text") String text, @JsonProperty("sources") List sources, + @JsonProperty("type") Type type) { + /** + * The type of citation which indicates what part of the response the citation + * is for. + */ + public enum Type { + + TEXT_CONTENT, PLAN + + } + + /** + * @param type Tool or A document source object containing the unique + * identifier of the document and the document itself. + * @param id The unique identifier of the document + * @param toolOutput map from strings to any Optional if type == tool + * @param document map from strings to any Optional if type == document + */ + public record Source(@JsonProperty("type") String type, @JsonProperty("id") String id, + @JsonProperty("tool_output") Map toolOutput, + @JsonProperty("document") Map document) { + } + } + + public record Provider(@JsonProperty("content") List content, @JsonProperty("role") Role role, + @JsonProperty("tool_plan") String toolPlan, @JsonProperty("tool_calls") List toolCalls, + @JsonProperty("citations") List citations) { + } + } + + /** + * Used to select the safety instruction inserted into the prompt. Defaults to + * CONTEXTUAL. When OFF is specified, the safety instruction will be omitted. Safety + * modes are not yet configurable in combination with tools, tool_results and + * documents parameters. Note: This parameter is only compatible newer Cohere models, + * starting with Command R 08-2024 and Command R+ 08-2024. Note: command-r7b-12-2024 + * and newer models only support "CONTEXTUAL" and "STRICT" modes. + */ + public enum SafetyMode { + + CONTEXTUAL, STRICT, OFF + + } + + /** + * Options for controlling citation generation. Defaults to "accurate". Dictates the + * approach taken to generating citations as part of the RAG flow by allowing the user + * to specify whether they want "accurate" results, "fast" results or no results. + * Note: command-r7b-12-2024 and command-a-03-2025 only support "fast" and "off" + * modes. The default is "fast". + */ + public record CitationOptions(@JsonProperty("mode") CitationMode mode) { + } + + /** + * Options for controlling citation generation. Defaults to "accurate". Dictates the + * approach taken to generating citations as part of the RAG flow by allowing the user + * to specify whether they want "accurate" results, "fast" results or no results. + * Note: command-r7b-12-2024 and command-a-03-2025 only support "fast" and "off" + * modes. The default is "fast". + */ + public enum CitationMode { + + FAST, ACCURATE, OFF + + } + + /** + * relevant documents that the model can cite to generate a more accurate reply. Each + * document is either a string or document object with rawContent and metadata. + * + * @param id An optional Unique identifier for this document which will be referenced + * in citations. If not provided an ID will be automatically generated. + * @param data A relevant document that the model can cite to generate a more accurate + * reply. Each document is a string-any dictionary. + */ + public record Document(@JsonProperty("id") String id, @JsonProperty("data") String data) { + } + + /** + * Represents a tool the model may call. Currently, only functions are supported as a + * tool. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class FunctionTool { + + // The type of the tool. Currently, only 'function' is supported. + @JsonProperty("type") + Type type = Type.FUNCTION; + + // The function definition. + @JsonProperty("function") + Function function; + + public FunctionTool() { + + } + + /** + * Create a tool of type 'function' and the given function definition. + * @param function function definition. + */ + public FunctionTool(Function function) { + this(Type.FUNCTION, function); + } + + public FunctionTool(Type type, Function function) { + this.type = type; + this.function = function; + } + + public Type getType() { + return this.type; + } + + public Function getFunction() { + return this.function; + } + + public void setType(Type type) { + this.type = type; + } + + public void setFunction(Function function) { + this.function = function; + } + + /** + * Create a tool of type 'function' and the given function definition. + */ + public enum Type { + + /** + * Function tool type. + */ + @JsonProperty("function") + FUNCTION + + } + + /** + * Function definition. + */ + public static class Function { + + @JsonProperty("description") + private String description; + + @JsonProperty("name") + private String name; + + @JsonProperty("parameters") + private Map parameters; + + @JsonIgnore + private String jsonSchema; + + private Function() { + + } + + /** + * Create tool function definition. + * @param description A description of what the function does, used by the + * model to choose when and how to call the function. + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, + * or contain underscores and dashes, with a maximum length of 64. + * @param parameters The parameters the functions accepts, described as a JSON + * Schema object. To describe a function that accepts no parameters, provide + * the value {"type": "object", "properties": {}}. + */ + public Function(String description, String name, Map parameters) { + this.description = description; + this.name = name; + this.parameters = parameters; + } + + /** + * Create tool function definition. + * @param description tool function description. + * @param name tool function name. + * @param jsonSchema tool function schema as json. + */ + public Function(String description, String name, String jsonSchema) { + this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); + } + + public String getDescription() { + return this.description; + } + + public String getName() { + return this.name; + } + + public Map getParameters() { + return this.parameters; + } + + public void setDescription(String description) { + this.description = description; + } + + public void setName(String name) { + this.name = name; + } + + public void setParameters(Map parameters) { + this.parameters = parameters; + } + + public String getJsonSchema() { + return this.jsonSchema; + } + + public void setJsonSchema(String jsonSchema) { + this.jsonSchema = jsonSchema; + if (jsonSchema != null) { + this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema); + } + } + + } + + } + + /** + * Represents a chat completion response returned by model, based on the provided + * input. + * + * @param id A unique identifier for the chat completion. + * @param finishReason The reason the model stopped generating tokens. + * @param message A chat completion message generated by streamed model responses. + * @param logprobs Log probability information for the choice. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletion(@JsonProperty("id") String id, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("message") ChatCompletionMessage.Provider message, + @JsonProperty("logprobs") LogProbs logprobs, @JsonProperty("usage") Usage usage) { + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model finished sending a complete message. + */ + COMPLETE, + + /** + * One of the provided stop_sequence entries was reached in the model’s + * generation. + */ + STOP_SEQUENCE, + + /** + * The number of generated tokens exceeded the model’s context length or the value + * specified via the max_tokens parameter. + */ + MAX_TOKENS, + + /** + * The model generated a Tool Call and is expecting a Tool Message in return + */ + TOOL_CALL, + + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + + /** + * The generation failed due to an internal error + */ + ERROR + + } + + /** + * Log probability information + * + * @param tokenIds The token ids of each token used to construct the text chunk. + * @param text The text chunk for which the log probabilities was calculated. + * @param logprobs The log probability of each token used to construct the text chunk. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record LogProbs(@JsonProperty("token_ids") List tokenIds, @JsonProperty("text") String text, + @JsonProperty("logprobs") List logprobs) { + + } + + /** + * Helper factory that creates a tool_choice of type 'REQUIRED', 'NONE' or selected + * function by name. + */ + public static class ToolChoiceBuilder { + + public static final String NONE = "NONE"; + + public static final String REQUIRED = "REQUIRED"; + + /** + * Specifying a particular function forces the model to call that function. + */ + public static Object FUNCTION(String functionName) { + return Map.of("type", "function", "function", Map.of("name", functionName)); + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatCompletionChunk( + // @formatter:off + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ChunkDelta delta) { + // @formatter:on + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChunkDelta( + // @formatter:off + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("usage") Usage usage) { + // @formatter:on + } + + } + + /** + * List of well-known Cohere embedding models. + * + * @see Cohere Models Overview + */ + public enum EmbeddingModel { + + // @formatter:off + + /** + * A model that allows for text and images to be classified or turned into embeddings + * dimensional - [256, 512, 1024, 1536 (default)] + */ + EMBED_V4("embed-v4.0"), + /** + * Embed v3 Multilingual model for text embeddings. + * Produces 1024-dimensional embeddings suitable for multilingual semantic search, + * clustering, and other text similarity tasks. + */ + EMBED_MULTILINGUAL_V3("embed-multilingual-v3.0"), + + /** + * Embed v3 English model for text embeddings. + * Produces 1024-dimensional embeddings optimized for English text. + */ + EMBED_ENGLISH_V3("embed-english-v3.0"), + + /** + * Embed v3 Multilingual Light model. + * Smaller and faster variant with 1024 dimensions. + */ + EMBED_MULTILINGUAL_LIGHT_V3("embed-multilingual-light-v3.0"), + + /** + * Embed v3 English Light model. + * Smaller and faster English-only variant with 1024 dimensions. + */ + EMBED_ENGLISH_LIGHT_V3("embed-english-light-v3.0"); + // @formatter:on + + private final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * Embedding type + */ + public enum EmbeddingType { + + /** + * Use this when you want to get back the default float embeddings. Supported with + * all Embed models. + */ + @JsonProperty("float") + FLOAT, + + /** + * Use this when you want to get back signed int8 embeddings. Supported with Embed + * v3.0 and newer Embed models. + */ + @JsonProperty("int8") + INT8, + + /** + * Use this when you want to get back unsigned int8 embeddings. Supported with + * Embed v3.0 and newer Embed models. + */ + @JsonProperty("uint8") + UINT8, + + /** + * Use this when you want to get back signed binary embeddings. Supported with + * Embed v3.0 and newer Embed models. + */ + @JsonProperty("binary") + BINARY, + + /** + * Use this when you want to get back unsigned binary embeddings. Supported with + * Embed v3.0 and newer Embed models. + */ + @JsonProperty("ubinary") + UBINARY, + + /** + * Use this when you want to get back base64 embeddings. Supported with Embed v3.0 + * and newer Embed models. + */ + @JsonProperty("base64") + BASE64 + + } + + /** + * Embedding request. + * + * @param texts An array of strings to embed. + * @param images An array of images to embed as data URIs. + * @param model The model to use for embedding. + * @param inputType The type of input (search_document, search_query, classification, + * clustering, image). + * @param embeddingTypes The types of embeddings to return (float, int8, uint8, + * binary, ubinary). + * @param truncate How to handle inputs longer than the maximum token length (NONE, + * START, END). + * @param Type of the input (String or List of tokens). + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingRequest( + // @formatter:off + @JsonProperty("texts") List texts, + @JsonProperty("images") List images, + @JsonProperty("model") String model, + @JsonProperty("input_type") InputType inputType, + @JsonProperty("embedding_types") List embeddingTypes, + @JsonProperty("truncate") Truncate truncate) { + // @formatter:on + + public static Builder builder() { + return new Builder<>(); + } + + public static final class Builder { + + private String model = EmbeddingModel.EMBED_V4.getValue(); + + private List texts; + + private List images; + + private InputType inputType = InputType.SEARCH_DOCUMENT; + + private List embeddingTypes = List.of(EmbeddingType.FLOAT); + + private Truncate truncate = Truncate.END; + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder texts(Object raw) { + if (raw == null) { + this.texts = null; + return this; + } + + if (raw instanceof List list) { + this.texts = (List) list; + } + else { + this.texts = List.of((T) raw); + } + return this; + } + + public Builder images(List images) { + this.images = images; + return this; + } + + public Builder inputType(InputType inputType) { + this.inputType = inputType; + return this; + } + + public Builder embeddingTypes(List embeddingTypes) { + this.embeddingTypes = embeddingTypes; + return this; + } + + public Builder truncate(Truncate truncate) { + this.truncate = truncate; + return this; + } + + public EmbeddingRequest build() { + return new EmbeddingRequest<>(this.texts, this.images, this.model, this.inputType, this.embeddingTypes, + this.truncate); + } + + } + + /** + * Input type for embeddings. + */ + public enum InputType { + + // @formatter:off + @JsonProperty("search_document") + SEARCH_DOCUMENT, + @JsonProperty("search_query") + SEARCH_QUERY, + @JsonProperty("classification") + CLASSIFICATION, + @JsonProperty("clustering") + CLUSTERING, + @JsonProperty("image") + IMAGE + // @formatter:on + + } + + /** + * Truncation strategy for inputs longer than maximum token length. + */ + public enum Truncate { + + // @formatter:off + @JsonProperty("NONE") + NONE, + @JsonProperty("START") + START, + @JsonProperty("END") + END + // @formatter:on + + } + + } + + /** + * Embedding response. + * + * @param id Unique identifier for the embedding request. + * @param embeddings The embeddings + * @param texts The texts that were embedded. + * @param responseType The type of response ("embeddings_floats" or + * "embeddings_by_type"). + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record EmbeddingResponse( + // @formatter:off + @JsonProperty("id") String id, + @JsonProperty("embeddings") Object embeddings, + @JsonProperty("texts") List texts, + @JsonProperty("response_type") String responseType) { + // @formatter:on + + /** + * Extracts float embeddings from the response. Handles both response formats: - + * "embeddings_floats": embeddings is List<List<Double>> - + * "embeddings_by_type": embeddings is Map with "float" key containing + * List<List<Double>> + * @return List of float arrays representing the embeddings + */ + @JsonIgnore + @SuppressWarnings("unchecked") + public List getFloatEmbeddings() { + if (this.embeddings == null) { + return List.of(); + } + + // Handle "embeddings_floats" format: embeddings is directly + // List> + if (this.embeddings instanceof List embeddingsList) { + return embeddingsList.stream().map(embedding -> { + if (embedding instanceof List embeddingVector) { + float[] floatArray = new float[embeddingVector.size()]; + for (int i = 0; i < embeddingVector.size(); i++) { + Object value = embeddingVector.get(i); + floatArray[i] = (value instanceof Number number) ? number.floatValue() : 0f; + } + return floatArray; + } + return new float[0]; + }).toList(); + } + + // Handle "embeddings_by_type" format: embeddings is Map + if (this.embeddings instanceof Map embeddingsMap) { + Object floatEmbeddings = embeddingsMap.get("float"); + if (floatEmbeddings instanceof List embeddingsList) { + return embeddingsList.stream().map(embedding -> { + if (embedding instanceof List embeddingVector) { + float[] floatArray = new float[embeddingVector.size()]; + for (int i = 0; i < embeddingVector.size(); i++) { + Object value = embeddingVector.get(i); + floatArray[i] = (value instanceof Number number) ? number.floatValue() : 0f; + } + return floatArray; + } + return new float[0]; + }).toList(); + } + } + + return List.of(); + } + + } + + public enum EventType { + + MESSAGE_END("message-end"), CONTENT_START("content-start"), CONTENT_DELTA("content-delta"), + CONTENT_END("content-end"), TOOL_PLAN_DELTA("tool-plan-delta"), TOOL_CALL_START("tool-call-start"), + TOOL_CALL_DELTA("tool-call-delta"), CITATION_START("citation-start"); + + public final String value; + + EventType(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java new file mode 100644 index 00000000000..63a961ac972 --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java @@ -0,0 +1,296 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionChunk; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ChatCompletionFunction; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.cohere.api.CohereApi.EventType; +import org.springframework.util.ObjectUtils; + +/** + * Helper class for handling streaming function calling in Cohere API. + * + * @author Ricken Bazolo + */ +public class CohereStreamFunctionCallingHelper { + + /** + * Merge the previous and current ChatCompletionChunk into a single one. + * @param previous the previous ChatCompletionChunk + * @param current the current ChatCompletionChunk + * @return the merged ChatCompletionChunk + */ + public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { + + if (previous == null) { + return current; + } + + if (current == null) { + return previous; + } + + var previousDelta = previous.delta(); + var currentDelta = current.delta(); + + ChatCompletionMessage previousMessage = previousDelta != null ? previousDelta.message() : null; + ChatCompletionMessage currentMessage = currentDelta != null ? currentDelta.message() : null; + + Role role = previousMessage != null && previousMessage.role() != null ? previousMessage.role() + : (currentMessage != null ? currentMessage.role() : null); + + String previousText = previousMessage != null ? extractTextFromRawContent(previousMessage.rawContent()) : ""; + + String currentText = currentMessage != null ? extractTextFromRawContent(currentMessage.rawContent()) : ""; + + String currentType = current.type(); + String mergedText; + if (EventType.CONTENT_START.getValue().equals(currentType)) { + mergedText = currentText; + } + else if (EventType.CONTENT_END.getValue().equals(currentType)) { + mergedText = previousText; + } + else { + mergedText = previousText + currentText; + } + + String previousPlan = previousMessage != null ? previousMessage.toolPlan() : null; + String currentPlan = currentMessage != null ? currentMessage.toolPlan() : null; + + String mergedToolPlan = previousPlan; + + if (EventType.TOOL_PLAN_DELTA.getValue().equals(current.type())) { + mergedToolPlan = mergeToolPlan(previousPlan, currentPlan); + } + + List mergedToolCalls = mergeToolCalls(previous, current); + + List citations = mergeCitations(previous, current); + + ChatCompletionMessage mergedMessage = new ChatCompletionMessage(mergedText, role, mergedToolPlan, + mergedToolCalls, citations, null); + + var finishReason = (currentDelta != null && currentDelta.finishReason() != null) ? currentDelta.finishReason() + : (previousDelta != null ? previousDelta.finishReason() : null); + + var usage = (currentDelta != null && currentDelta.usage() != null) ? currentDelta.usage() + : (previousDelta != null ? previousDelta.usage() : null); + + var mergedDelta = new ChatCompletionChunk.ChunkDelta(mergedMessage, finishReason, usage); + + String id = current.id() != null ? current.id() : previous.id(); + String type = current.type() != null ? current.type() : previous.type(); + Integer index = current.index() != null ? current.index() : previous.index(); + + return new ChatCompletionChunk(id, type, index, mergedDelta); + } + + public ChatCompletionChunk sanitizeToolCalls(ChatCompletionChunk chunk) { + if (chunk == null || chunk.delta() == null || chunk.delta().message() == null) { + return chunk; + } + + ChatCompletionMessage msg = chunk.delta().message(); + List toolCalls = msg.toolCalls(); + + if (toolCalls == null || toolCalls.isEmpty()) { + return chunk; + } + + List cleaned = toolCalls.stream().filter(this::isValidToolCall).toList(); + + ChatCompletionChunk.ChunkDelta oldDelta = chunk.delta(); + + ChatCompletionMessage cleanedMsg = new ChatCompletionMessage(msg.rawContent(), msg.role(), msg.toolPlan(), + cleaned.isEmpty() ? null : cleaned, msg.citations(), null); + + ChatCompletionChunk.ChunkDelta newDelta = new ChatCompletionChunk.ChunkDelta(cleanedMsg, + oldDelta.finishReason(), oldDelta.usage()); + + return new ChatCompletionChunk(chunk.id(), chunk.type(), chunk.index(), newDelta); + } + + public boolean hasValidToolCallsOnly(ChatCompletionChunk c) { + if (c == null || c.delta() == null || c.delta().message() == null) { + return false; + } + + ChatCompletionMessage message = c.delta().message(); + List calls = message.toolCalls(); + + boolean hasValidToolCalls = calls != null && calls.stream().anyMatch(this::isValidToolCall); + + boolean hasTextContent = message.rawContent() != null + && !extractTextFromRawContent(message.rawContent()).isEmpty(); + + boolean hasCitations = message.citations() != null && !message.citations().isEmpty(); + + return hasValidToolCalls || hasTextContent || hasCitations; + } + + private boolean isValidToolCall(ToolCall toolCall) { + if (toolCall == null || toolCall.function() == null) { + return false; + } + ChatCompletionFunction chatCompletionFunction = toolCall.function(); + String functionName = chatCompletionFunction.name(); + String functionArguments = chatCompletionFunction.arguments(); + return !ObjectUtils.isEmpty(functionName) && !ObjectUtils.isEmpty(functionArguments); + } + + private String extractTextFromRawContent(Object rawContent) { + if (rawContent == null) { + return ""; + } + if (rawContent instanceof Map map) { + Object text = map.get("text"); + if (text != null) { + return text.toString(); + } + } + if (rawContent instanceof List list) { + StringBuilder sb = new StringBuilder(); + for (Object item : list) { + if (item instanceof Map m) { + Object text = m.get("text"); + if (text != null) { + sb.append(text); + } + } + else if (item instanceof String s) { + sb.append(s); + } + } + return sb.toString(); + } + if (rawContent instanceof String s) { + return s; + } + return rawContent.toString(); + } + + private List mergeToolCalls(ChatCompletionChunk previous, ChatCompletionChunk current) { + ChatCompletionMessage previousMessage = previous != null && previous.delta() != null + ? previous.delta().message() : null; + ChatCompletionMessage currentMessage = current.delta() != null ? current.delta().message() : null; + + List merged = ensureToolCallList(previousMessage != null ? previousMessage.toolCalls() : null); + + String type = current.type(); + Integer index = current.index(); + + if (index == null) { + return merged; + } + + ToolCall existing = ensureToolCallAtIndex(merged, index); + ChatCompletionFunction existingFunction = existing.function() != null ? existing.function() + : new ChatCompletionFunction(null, ""); + + String id = existing.id(); + String callType = existing.type(); + String functionName = existingFunction.name(); + String args = existingFunction.arguments() != null ? existingFunction.arguments() : ""; + + if (EventType.TOOL_CALL_START.getValue().equals(type) && currentMessage != null + && currentMessage.toolCalls() != null && !currentMessage.toolCalls().isEmpty()) { + + ToolCall start = currentMessage.toolCalls().get(0); + ChatCompletionFunction startFunction = start.function() != null ? start.function() + : new ChatCompletionFunction(null, ""); + + id = start.id() != null ? start.id() : id; + callType = start.type() != null ? start.type() : callType; + functionName = startFunction.name() != null ? startFunction.name() : functionName; + + } + + if (EventType.TOOL_CALL_DELTA.getValue().equals(type) && currentMessage != null + && currentMessage.toolCalls() != null && !currentMessage.toolCalls().isEmpty()) { + + ToolCall deltaCall = currentMessage.toolCalls().get(0); + ChatCompletionFunction deltaFunction = deltaCall.function(); + if (deltaFunction != null && deltaFunction.arguments() != null) { + args = (args == null ? "" : args) + deltaFunction.arguments(); + } + } + + // tool-call-end + ChatCompletionFunction mergedFn = new ChatCompletionFunction(functionName, args); + ToolCall mergedCall = new ToolCall(id, callType, mergedFn, index); + merged.set(index, mergedCall); + + return merged; + } + + private String mergeToolPlan(final String previous, final String currentFragment) { + if (currentFragment == null || currentFragment.isEmpty()) { + return previous; + } + if (previous == null) { + return currentFragment; + } + return previous + currentFragment; + } + + private List mergeCitations(final ChatCompletionChunk previous, + final ChatCompletionChunk current) { + + ChatCompletionMessage previousMessage = previous != null && previous.delta() != null + ? previous.delta().message() : null; + ChatCompletionMessage currentMessage = current != null && current.delta() != null ? current.delta().message() + : null; + + List merged = new ArrayList<>(); + + if (previousMessage != null && previousMessage.citations() != null) { + merged.addAll(previousMessage.citations()); + } + + if (current != null && EventType.CITATION_START.getValue().equals(current.type()) && currentMessage != null + && currentMessage.citations() != null) { + merged.addAll(currentMessage.citations()); + } + + return merged.isEmpty() ? null : merged; + } + + private List ensureToolCallList(final List toolCalls) { + return (toolCalls != null) ? new ArrayList<>(toolCalls) : new ArrayList<>(); + } + + private ToolCall ensureToolCallAtIndex(final List toolCalls, final int index) { + while (toolCalls.size() <= index) { + toolCalls.add(new ToolCall(null, null, new ChatCompletionFunction("", ""), index)); + } + ToolCall call = toolCalls.get(index); + if (call == null) { + call = new ToolCall(null, null, new ChatCompletionFunction("", ""), index); + toolCalls.set(index, call); + } + return call; + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java new file mode 100644 index 00000000000..0e8e8382333 --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java @@ -0,0 +1,689 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletion; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest; +import org.springframework.ai.cohere.api.CohereApi.FunctionTool; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.support.UsageCalculator; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; + +/** + * Represents a Cohere Chat Model. + * + * @author Ricken Bazolo + */ +public class CohereChatModel implements ChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + /** + * The default options used for the chat completion requests. + */ + private final CohereChatOptions defaultOptions; + + /** + * Low-level access to the Cohere API. + */ + private final CohereApi cohereApi; + + private final RetryTemplate retryTemplate; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + private final ToolCallingManager toolCallingManager; + + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + /** + * Conventions to use for generating observations. + */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public CohereChatModel(CohereApi cohereApi, CohereChatOptions defaultOptions, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + this(cohereApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public CohereChatModel(CohereApi cohereApi, CohereChatOptions defaultOptions, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + Assert.notNull(cohereApi, "cohereApi cannot be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(retryTemplate, "retryTemplate cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); + this.cohereApi = cohereApi; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + } + + public static ChatResponseMetadata from(ChatCompletion result) { + Assert.notNull(result, "Cohere ChatCompletion must not be null"); + DefaultUsage usage = getDefaultUsage(result.usage()); + return ChatResponseMetadata.builder().id(result.id()).usage(usage).build(); + } + + public static ChatResponseMetadata from(ChatCompletion result, Usage usage) { + Assert.notNull(result, "Cohere ChatCompletion must not be null"); + return ChatResponseMetadata.builder().id(result.id()).usage(usage).build(); + } + + private static DefaultUsage getDefaultUsage(CohereApi.Usage usage) { + return new DefaultUsage(usage.tokens().inputTokens(), usage.tokens().outputTokens(), null, usage); + } + + @Override + public ChatResponse call(Prompt prompt) { + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); + } + + @Override + public Flux stream(Prompt prompt) { + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return Flux.deferContextual(contextView -> { + var request = createRequest(prompt, true); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(CohereApi.PROVIDER_NAME) + .build(); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux completionChunks = RetryUtils.execute(this.retryTemplate, + () -> this.cohereApi.chatCompletionStream(request)); + + // For chunked responses, only the first chunk contains the role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse + // the function call handling logic. + Flux chatResponse = completionChunks.map(this::toChatCompletion) + .filter(chatCompletion -> chatCompletion != null && chatCompletion.message() != null) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(completion -> { + try { + @SuppressWarnings("null") + String id = completion.id(); + ChatCompletionMessage.Provider message = completion.message(); + + // Store the role for this completion ID + if (message.role() != null && id != null) { + roleMap.putIfAbsent(id, message.role().name()); + } + + List generations = message.content().stream().map(content -> { + Map metadata = Map.of("id", completion.id() != null ? completion.id() : "", + "role", completion.id() != null ? roleMap.getOrDefault(id, "") : "", "finishReason", + completion.finishReason() != null ? completion.finishReason().name() : ""); + return buildGeneration(content, completion, metadata); + }).toList(); + + if (completion.usage() != null) { + DefaultUsage usage = getDefaultUsage(completion.usage()); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse); + return new ChatResponse(generations, from(completion, cumulativeUsage)); + } + else { + return new ChatResponse(generations); + } + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + })); + + // @formatter:off + Flux chatResponseFlux = chatResponse.flatMap(response -> { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + return Flux.deferContextual(ctx -> { + ToolExecutionResult toolExecutionResult; + try { + ToolCallReactiveContextHolder.setContext(ctx); + toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + } + finally { + ToolCallReactiveContextHolder.clearContext(); + } + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // Send the tool execution result back to the model. + var chatOptions = CohereChatOptions.fromOptions(prompt.getOptions().copy()); + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), chatOptions), + response); + } + }).subscribeOn(Schedulers.boundedElastic()); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on; + + return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); + }); + + } + + private ChatCompletion toChatCompletion(CohereApi.ChatCompletionChunk chunk) { + if (chunk == null || chunk.delta() == null) { + return null; + } + + CohereApi.ChatCompletionChunk.ChunkDelta delta = chunk.delta(); + ChatCompletionMessage message = delta.message(); + + ChatCompletionMessage.Provider provider = null; + if (message != null) { + + List content = extractMessageContent(message.rawContent()); + + provider = new ChatCompletionMessage.Provider(content, message.role(), message.toolPlan(), + message.toolCalls(), message.citations()); + } + + return new ChatCompletion(chunk.id(), delta.finishReason(), provider, null, delta.usage()); + } + + private List extractMessageContent(Object rawContent) { + if (rawContent == null) { + return List.of(); + } + + if (rawContent instanceof String text) { + return List.of(new ChatCompletionMessage.MessageContent("text", text, null)); + } + + if (rawContent instanceof List list) { + List messageContents = new ArrayList<>(); + for (Object item : list) { + if (item instanceof ChatCompletionMessage.MessageContent mc) { + messageContents.add(mc); + } + else if (item instanceof Map map) { + String type = (String) map.get("type"); + String text = (String) map.get("text"); + Object value = map.get("value"); + messageContents.add(new ChatCompletionMessage.MessageContent(type, text, value)); + } + } + return messageContents; + } + + if (rawContent instanceof Map map) { + String type = (String) map.get("type"); + String text = (String) map.get("text"); + Object value = map.get("value"); + return List.of(new ChatCompletionMessage.MessageContent(type != null ? type : "text", text, value)); + } + + return List.of(); + } + + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + CohereChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + CohereChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + CohereChatOptions.class); + } + } + + // Define request options by merging runtime options and default options + CohereChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + CohereChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), + this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + ChatCompletionRequest request = createRequest(prompt, false); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(CohereApi.PROVIDER_NAME) + .build(); + + ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.cohereApi.chatCompletionEntity(request)); + + ChatCompletion chatCompletion = completionEntity.getBody(); + + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + final Map metadata = Map.of("id", + chatCompletion.id() != null ? chatCompletion.id() : "", "role", + chatCompletion.message().role() != null ? chatCompletion.message().role().name() : "", + "finishReason", + chatCompletion.finishReason() != null ? chatCompletion.finishReason().name() : ""); + + List generations = new ArrayList<>(); + + if (chatCompletion.finishReason() == null) { // Just for secure + logger.warn("No chat completion finishReason returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + if (chatCompletion.finishReason().equals(CohereApi.ChatCompletionFinishReason.TOOL_CALL)) { + var generation = buildGeneration(null, chatCompletion, metadata); + generations.add(generation); + } + else { + generations = chatCompletion.message() + .content() + .stream() + .map(content -> buildGeneration(content, chatCompletion, metadata)) + .toList(); + } + + DefaultUsage usage = getDefaultUsage(completionEntity.getBody().usage()); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse); + + ChatResponse chatResponse = new ChatResponse(generations, + from(completionEntity.getBody(), cumulativeUsage)); + + observationContext.setResponse(chatResponse); + + return chatResponse; + }); + + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // remove tools actions before + ChatOptions chatOptions = CohereChatOptions.fromOptions2(prompt.getOptions().copy()); + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), chatOptions), null); + } + } + + return response; + } + + private Generation buildGeneration(ChatCompletionMessage.MessageContent content, ChatCompletion completion, + Map metadata) { + List toolCalls = completion.message().toolCalls() == null ? List.of() + : completion.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + var assistantMessage = AssistantMessage.builder() + .content(content != null ? content.text() : "") + .toolCalls(toolCalls) + .properties(metadata) + .build(); + + String finishReason = (completion.finishReason() != null ? completion.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); + return new Generation(assistantMessage, generationMetadata); + } + + /** + * Accessible for testing. + */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + List chatCompletionMessages = prompt.getInstructions() + .stream() + .map(this::convertToCohereMessage) + .flatMap(List::stream) + .toList(); + + var request = new ChatCompletionRequest(chatCompletionMessages, stream); + + CohereChatOptions requestOptions = (CohereChatOptions) prompt.getOptions(); + request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); + + // Add the tool definitions to the request's tools parameter. + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + request = ModelOptionsUtils.merge( + CohereChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, + ChatCompletionRequest.class); + } + + return request; + } + + /** + * Convert a Spring AI message to Cohere API message(s). + * @param message the Spring AI message to convert + * @return list of Cohere ChatCompletionMessage(s) + */ + private List convertToCohereMessage(Message message) { + return switch (message.getMessageType()) { + case USER -> convertUserMessage(message); + case ASSISTANT -> convertAssistantMessage(message); + case SYSTEM -> convertSystemMessage(message); + case TOOL -> convertToolMessage(message); + }; + } + + /** + * Convert a USER message. + */ + private List convertUserMessage(org.springframework.ai.chat.messages.Message message) { + Object content = message.getText(); + + if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) { + // Validate images before processing + CohereImageValidator.validateImages(userMessage.getMedia()); + + List contentList = new ArrayList<>( + List.of(new ChatCompletionMessage.MediaContent(message.getText()))); + + contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); + + content = contentList; + } + + return List.of(new ChatCompletionMessage(content, Role.USER)); + } + + /** + * Convert an ASSISTANT message. + */ + private List convertAssistantMessage(org.springframework.ai.chat.messages.Message message) { + if (!(message instanceof AssistantMessage assistantMessage)) { + throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName()); + } + + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = convertToolCalls(assistantMessage.getToolCalls()); + } + + return List.of(new ChatCompletionMessage(assistantMessage.getText(), Role.ASSISTANT, toolCalls)); + } + + /** + * Convert tool calls. + */ + private List convertToolCalls(List springToolCalls) { + return springToolCalls.stream().map(toolCall -> { + var function = new ChatCompletionMessage.ChatCompletionFunction(toolCall.name(), toolCall.arguments()); + return new ToolCall(toolCall.id(), toolCall.type(), function, null); + }).toList(); + } + + /** + * Convert a SYSTEM message. + */ + private List convertSystemMessage(org.springframework.ai.chat.messages.Message message) { + return List.of(new ChatCompletionMessage(message.getText(), Role.SYSTEM)); + } + + /** + * Convert a TOOL response message to Cohere format. Validates that all tool responses + * have an ID. + */ + private List convertToolMessage(Message message) { + + if (!(message instanceof ToolResponseMessage toolResponseMessage)) { + throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName()); + } + + return toolResponseMessage.getResponses().stream().map(toolResponse -> { + Assert.notNull(toolResponse.id(), "ToolResponseMessage.ToolResponse must have an id"); + return new ChatCompletionMessage(toolResponse.responseData(), Role.TOOL, toolResponse.name(), null, null, + toolResponse.id()); + }).toList(); + } + + private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) { + CohereApi.ChatCompletionMessage.MediaContent.DetailLevel detail = this.defaultOptions != null + ? this.defaultOptions.getImageDetail() : null; + return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl( + this.fromMediaData(media.getMimeType(), media.getData()), detail)); + } + + private String fromMediaData(MimeType mimeType, Object mediaContentData) { + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 encoded + // following the prefix pattern. + return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + } + + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var function = new FunctionTool.Function(toolDefinition.description(), toolDefinition.name(), + toolDefinition.inputSchema()); + return new FunctionTool(function); + }).toList(); + } + + @Override + public ChatOptions getDefaultOptions() { + return CohereChatOptions.fromOptions(this.defaultOptions); + } + + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private CohereApi cohereApi; + + private CohereChatOptions defaultOptions = CohereChatOptions.builder() + .temperature(0.3) + .topP(1.0) + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .build(); + + private ToolCallingManager toolCallingManager; + + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder cohereApi(CohereApi cohereApi) { + this.cohereApi = cohereApi; + return this; + } + + public Builder defaultOptions(CohereChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public CohereChatModel build() { + if (this.toolCallingManager != null) { + return new CohereChatModel(this.cohereApi, this.defaultOptions, this.toolCallingManager, + this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate); + } + return new CohereChatModel(this.cohereApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, + this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate); + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java new file mode 100644 index 00000000000..6f9955180cf --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java @@ -0,0 +1,624 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.MediaContent.DetailLevel; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.cohere.api.CohereApi.FunctionTool; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Options for the Cohere API. + * + * @author Ricken Bazolo + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CohereChatOptions implements ToolCallingChatOptions { + + /** + * ID of the model to use + */ + private @JsonProperty("model") String model; + + /** + * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more focused + * and deterministic. We generally recommend altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + + /** + * Ensures that only the most likely tokens, with total probability mass of p, are + * considered for generation at each step. If both k and p are enabled, p acts after + * k. Defaults to 0.75. min value of 0.01, max value of 0.99. + */ + private @JsonProperty("p") Double p; + + /** + * The maximum number of tokens to generate in the chat completion. The total length + * of input tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + + /** + * Min value of 0.0, max value of 1.0. Used to reduce repetitiveness of generated + * tokens. Similar to frequency_penalty, except that this penalty is applied equally + * to all tokens that have already appeared, regardless of their exact frequencies. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + + /** + * Min value of 0.0, max value of 1.0. Used to reduce repetitiveness of generated + * tokens. Similar to frequency_penalty, except that this penalty is applied equally + * to all tokens that have already appeared, regardless of their exact frequencies. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + + /** + * Ensures that only the top k most likely tokens are considered for generation at + * each step. When k is set to 0, k-sampling is disabled. Defaults to 0, min value of + * 0, max value of 500. + */ + private @JsonProperty("k") Integer k; + + /** + * A list of tools the model may call. Currently, only functions are supported as a + * tool. Use this to provide a list of functions the model may generate JSON inputs + * for. + */ + private @JsonProperty("tools") List tools; + + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates + * is valid JSON. + */ + private @JsonProperty("response_format") ResponseFormat responseFormat; + + /** + * Used to select the safety instruction inserted into the prompt. Defaults to + * CONTEXTUAL. When OFF is specified, the safety instruction will be omitted. + */ + private @JsonProperty("safety_mode") CohereApi.SafetyMode safetyMode; + + /** + * A list of up to 5 strings that the model will use to stop generation. If the model + * generates a string that matches any of the strings in the list, it will stop + * generating tokens and return the generated text up to that point not including the + * stop sequence. + */ + private @JsonProperty("stop_sequences") List stopSequences; + + /** + * If specified, the backend will make a best effort to sample tokens + * deterministically, such that repeated requests with the same seed and parameters + * should return the same result. However, determinism cannot be totally guaranteed. + */ + private @JsonProperty("seed") Integer seed; + + /** + * Defaults to false. When set to true, the log probabilities of the generated tokens + * will be included in the response. + */ + private @JsonProperty("logprobs") Boolean logprobs; + + /** + * Controls which (if any) function is called by the model. none means the model will + * not call a function and instead generates a message. auto means the model can pick + * between generating a message or calling a function. Specifying a particular + * function via {"type: "function", "function": {"name": "my_function"}} forces the + * model to call that function. none is the default when no functions are present. + * auto is the default if functions are present. Use the + * {@link CohereApi.ToolChoiceBuilder} to create a tool choice object. + */ + private @JsonProperty("tool_choice") ToolChoice toolChoice; + + private @JsonProperty("strict_tools") Boolean strictTools; + + /** + * The level of detail for processing images. Can be "low", "high", or "auto". + * Defaults to "auto" if not specified. This controls the resolution at which the + * model views image. + */ + @JsonIgnore + private DetailLevel imageDetail; + + /** + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. + */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. + */ + @JsonIgnore + private Set toolNames = new HashSet<>(); + + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + @JsonIgnore + private Map toolContext = new HashMap<>(); + + public CohereApi.SafetyMode getSafetyMode() { + return this.safetyMode; + } + + public void setSafetyMode(CohereApi.SafetyMode safetyMode) { + this.safetyMode = safetyMode; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Boolean getStrictTools() { + return this.strictTools; + } + + public void setStrictTools(Boolean strictTools) { + this.strictTools = strictTools; + } + + public DetailLevel getImageDetail() { + return this.imageDetail; + } + + public void setImageDetail(DetailLevel imageDetail) { + this.imageDetail = imageDetail; + } + + public Double getP() { + return this.p; + } + + public void setP(Double p) { + this.p = p; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stopSequences; + } + + public void setStop(List stop) { + this.stopSequences = stop; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public ToolChoice getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return getP(); + } + + public void setTopP(Double topP) { + setP(topP); + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public CohereChatOptions copy() { + return fromOptions(this); + } + + @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean getInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return this.k; + } + + public void setTopK(Integer k) { + this.k = k; + } + + @Override + @JsonIgnore + public Map getToolContext() { + return this.toolContext; + } + + @Override + @JsonIgnore + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CohereChatOptions that = (CohereChatOptions) o; + return Objects.equals(this.model, that.model) && Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.p, that.p) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) && Objects.equals(this.k, that.k) + && Objects.equals(this.tools, that.tools) && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.safetyMode, that.safetyMode) + && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.toolChoice, that.toolChoice) + && Objects.equals(this.strictTools, that.strictTools) + && Objects.equals(this.imageDetail, that.imageDetail) + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolContext, that.toolContext); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.temperature, this.p, this.maxTokens, this.presencePenalty, + this.frequencyPenalty, this.k, this.tools, this.responseFormat, this.safetyMode, this.stopSequences, + this.seed, this.logprobs, this.toolChoice, this.strictTools, this.imageDetail, this.toolCallbacks, + this.toolNames, this.internalToolExecutionEnabled, this.toolContext); + } + + public static Builder builder() { + return new Builder(); + } + + public static CohereChatOptions fromOptions(CohereChatOptions fromOptions) { + Builder builder = builder().model(fromOptions.getModel()) + .temperature(fromOptions.getTemperature()) + .maxTokens(fromOptions.getMaxTokens()) + .topP(fromOptions.getTopP()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .topK(fromOptions.getTopK()) + .responseFormat(fromOptions.getResponseFormat()) + .safetyMode(fromOptions.getSafetyMode()) + .seed(fromOptions.getSeed()) + .logprobs(fromOptions.getLogprobs()) + .toolChoice(fromOptions.getToolChoice()) + .strictTools(fromOptions.getStrictTools()) + .imageDetail(fromOptions.getImageDetail()) + .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); + + // Create defensive copies of collections + if (fromOptions.getTools() != null) { + builder.tools(new ArrayList<>(fromOptions.getTools())); + } + if (fromOptions.getStopSequences() != null) { + builder.stop(new ArrayList<>(fromOptions.getStopSequences())); + } + if (fromOptions.getToolCallbacks() != null) { + builder.toolCallbacks(new ArrayList<>(fromOptions.getToolCallbacks())); + } + if (fromOptions.getToolNames() != null) { + builder.toolNames(new HashSet<>(fromOptions.getToolNames())); + } + if (fromOptions.getToolContext() != null) { + builder.toolContext(new HashMap<>(fromOptions.getToolContext())); + } + + return builder.build(); + } + + public static CohereChatOptions fromOptions2(CohereChatOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .temperature(fromOptions.getTemperature()) + .maxTokens(fromOptions.getMaxTokens()) + .topP(fromOptions.getTopP()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .topK(fromOptions.getTopK()) + .tools(null) + .responseFormat(fromOptions.getResponseFormat()) + .safetyMode(fromOptions.getSafetyMode()) + .stop(fromOptions.getStopSequences()) + .seed(fromOptions.getSeed()) + .logprobs(fromOptions.getLogprobs()) + .toolChoice(null) + .strictTools(null) + .toolCallbacks() + .toolNames() + .internalToolExecutionEnabled(null) + .build(); + } + + public static class Builder { + + private final CohereChatOptions options = new CohereChatOptions(); + + public CohereChatOptions build() { + return this.options; + } + + public Builder model(String model) { + this.options.setModel(model); + return this; + } + + public Builder model(CohereApi.ChatModel chatModel) { + this.options.setModel(chatModel.getName()); + return this; + } + + public Builder safetyMode(CohereApi.SafetyMode safetyMode) { + this.options.setSafetyMode(safetyMode); + return this; + } + + public Builder logprobs(Boolean logprobs) { + this.options.setLogprobs(logprobs); + return this; + } + + public Builder toolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder seed(Integer seed) { + this.options.setSeed(seed); + return this; + } + + public Builder stop(List stop) { + this.options.setStop(stop); + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder temperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder topP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder topK(Integer k) { + this.options.setTopK(k); + return this; + } + + public Builder responseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder tools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder strictTools(Boolean strictTools) { + this.options.setStrictTools(strictTools); + return this; + } + + public Builder toolChoice(ToolChoice toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + public Builder toolCallbacks(ToolCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); + return this; + } + + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + public Builder imageDetail(DetailLevel imageDetail) { + this.options.setImageDetail(imageDetail); + return this; + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java new file mode 100644 index 00000000000..43ee99c300f --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.List; +import java.util.Set; + +import org.springframework.ai.content.Media; + +/** + * Validator for Cohere API image constraints. + * + * @author Ricken Bazolo + */ +public final class CohereImageValidator { + + private static final int MAX_IMAGES_PER_REQUEST = 20; + + private static final long MAX_TOTAL_IMAGE_SIZE_BYTES = 20 * 1024 * 1024; + + private static final Set SUPPORTED_IMAGE_FORMATS = Set.of("image/jpeg", "image/png", "image/webp", + "image/gif"); + + private CohereImageValidator() { + } + + public static void validateImages(List mediaList) { + if (mediaList == null || mediaList.isEmpty()) { + return; + } + + validateImageCount(mediaList); + validateImageFormats(mediaList); + validateTotalImageSize(mediaList); + } + + private static void validateImageCount(List mediaList) { + if (mediaList.size() > MAX_IMAGES_PER_REQUEST) { + throw new IllegalArgumentException( + String.format("Cohere API supports maximum %d images per request, found: %d", + MAX_IMAGES_PER_REQUEST, mediaList.size())); + } + } + + private static void validateImageFormats(List mediaList) { + for (Media media : mediaList) { + var mimeType = media.getMimeType().toString(); + if (!SUPPORTED_IMAGE_FORMATS.contains(mimeType)) { + throw new IllegalArgumentException(String + .format("Unsupported image format: %s. Supported formats: JPEG, PNG, WebP, GIF", mimeType)); + } + } + } + + private static void validateTotalImageSize(List mediaList) { + long totalSize = 0; + + for (Media media : mediaList) { + long mediaSize = calculateMediaSize(media); + totalSize += mediaSize; + } + + if (totalSize > MAX_TOTAL_IMAGE_SIZE_BYTES) { + long totalSizeMB = totalSize / (1024 * 1024); + throw new IllegalArgumentException(String.format("Total image size exceeds 20MB limit: %dMB", totalSizeMB)); + } + } + + private static long calculateMediaSize(Media media) { + var data = media.getData(); + + if (data instanceof byte[] bytes) { + return bytes.length; + } + + if (data instanceof String text) { + if (text.startsWith("data:")) { + var base64Data = text.substring(text.indexOf(",") + 1); + return (long) (base64Data.length() * 0.75); + } + return 0; + } + + return 0; + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java new file mode 100644 index 00000000000..d0d526a0a0a --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java @@ -0,0 +1,249 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.List; +import java.util.Map; + +import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; + +/** + * Provides the Cohere Embedding Model. + * + * @author Ricken Bazolo + * @see AbstractEmbeddingModel + */ +public class CohereEmbeddingModel extends AbstractEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(CohereEmbeddingModel.class); + + /** + * Known embedding dimensions for Cohere models. Maps model names to their respective + * embedding vector dimensions. This allows the dimensions() method to return the + * correct value without making an API call. + */ + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of( + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue(), 1024, + CohereApi.EmbeddingModel.EMBED_ENGLISH_V3.getValue(), 1024, + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue(), 384, + CohereApi.EmbeddingModel.EMBED_ENGLISH_LIGHT_V3.getValue(), 384, + CohereApi.EmbeddingModel.EMBED_V4.getValue(), 1536); + + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + + private final CohereEmbeddingOptions defaultOptions; + + private final MetadataMode metadataMode; + + private final CohereApi cohereApi; + + private final RetryTemplate retryTemplate; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public CohereEmbeddingModel(CohereApi cohereApi, MetadataMode metadataMode, CohereEmbeddingOptions options, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + Assert.notNull(cohereApi, "cohereApi must not be null"); + Assert.notNull(metadataMode, "metadataMode must not be null"); + Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + + this.cohereApi = cohereApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + var apiRequest = createRequest(request); + + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(request) + .provider(CohereApi.PROVIDER_NAME) + .build(); + + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + var apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate, + () -> this.cohereApi.embeddings(apiRequest).getBody()); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned for request: {}", request); + return new EmbeddingResponse(List.of()); + } + + var metadata = generateResponseMetadata(apiEmbeddingResponse.responseType()); + + // Extract float embeddings from response + List floatEmbeddings = apiEmbeddingResponse.getFloatEmbeddings(); + + // Map to Spring AI Embedding objects with proper indexing + List embeddings = new java.util.ArrayList<>(); + for (int i = 0; i < floatEmbeddings.size(); i++) { + embeddings.add(new Embedding(floatEmbeddings.get(i), i)); + } + + var embeddingResponse = new EmbeddingResponse(embeddings, metadata); + + observationContext.setResponse(embeddingResponse); + + return embeddingResponse; + }); + } + + @Override + public float[] embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + private EmbeddingResponseMetadata generateResponseMetadata(String embeddingType) { + return new EmbeddingResponseMetadata(embeddingType, null); + } + + /** + * Use the provided convention for reporting observation data. + * @param observationConvention The provided convention + */ + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + private CohereApi.EmbeddingRequest createRequest(EmbeddingRequest request) { + CohereEmbeddingOptions options = mergeOptions(request.getOptions(), this.defaultOptions); + + return CohereApi.EmbeddingRequest.builder() + .model(options.getModel()) + .inputType(options.getInputType()) + .embeddingTypes(options.getEmbeddingTypes()) + .texts(request.getInstructions()) + .truncate(options.getTruncate()) + .build(); + } + + private CohereEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions, + CohereEmbeddingOptions defaultOptions) { + CohereEmbeddingOptions options = (requestOptions != null) + ? ModelOptionsUtils.merge(requestOptions, defaultOptions, CohereEmbeddingOptions.class) + : defaultOptions; + + if (options == null) { + throw new IllegalArgumentException("Embedding options must not be null"); + } + + return options; + } + + private CohereEmbeddingOptions buildRequestOptions(EmbeddingRequest request) { + return mergeOptions(request.getOptions(), this.defaultOptions); + } + + @Override + public int dimensions() { + String model = this.defaultOptions.getModel(); + if (model == null) { + return KNOWN_EMBEDDING_DIMENSIONS.get(CohereApi.EmbeddingModel.EMBED_V4.getValue()); + } + return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(model, 1024); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private CohereApi cohereApi; + + private MetadataMode metadataMode = MetadataMode.EMBED; + + private CohereEmbeddingOptions options = CohereEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue()) + .build(); + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + public Builder cohereApi(CohereApi cohereApi) { + this.cohereApi = cohereApi; + return this; + } + + public Builder metadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + return this; + } + + public Builder options(CohereEmbeddingOptions options) { + this.options = options; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public CohereEmbeddingModel build() { + return new CohereEmbeddingModel(this.cohereApi, this.metadataMode, this.options, this.retryTemplate, + this.observationRegistry); + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java new file mode 100644 index 00000000000..6b801ea26bb --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java @@ -0,0 +1,191 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.InputType; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.Truncate; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingType; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * Options for the Cohere Embedding API. + * + * @author Ricken Bazolo + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class CohereEmbeddingOptions implements EmbeddingOptions { + + /** + * ID of the model to use. + */ + @JsonProperty("model") + private String model; + + /** + * The type of input (search_document, search_query, classification, clustering). + */ + @JsonProperty("input_type") + private InputType inputType; + + /** + * The types of embeddings to return (float, int8, uint8, binary, ubinary). + */ + @JsonProperty("embedding_types") + private List embeddingTypes = new ArrayList<>(); + + /** + * How to handle inputs longer than the maximum token length (NONE, START, END). + */ + @JsonProperty("truncate") + private Truncate truncate; + + public static Builder builder() { + return new Builder(); + } + + public static CohereEmbeddingOptions fromOptions(CohereEmbeddingOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .inputType(fromOptions.getInputType()) + .embeddingTypes( + fromOptions.getEmbeddingTypes() != null ? new ArrayList<>(fromOptions.getEmbeddingTypes()) : null) + .truncate(fromOptions.getTruncate()) + .build(); + } + + private CohereEmbeddingOptions() { + this.embeddingTypes.add(EmbeddingType.FLOAT); + this.inputType = InputType.CLASSIFICATION; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public InputType getInputType() { + return this.inputType; + } + + public void setInputType(InputType inputType) { + this.inputType = inputType; + } + + public List getEmbeddingTypes() { + return this.embeddingTypes; + } + + public void setEmbeddingTypes(List embeddingTypes) { + this.embeddingTypes = embeddingTypes; + } + + public Truncate getTruncate() { + return this.truncate; + } + + public void setTruncate(Truncate truncate) { + this.truncate = truncate; + } + + @Override + public Integer getDimensions() { + // Cohere embeddings have fixed dimensions based on model + // embed-multilingual-v3 and embed-english-v3: 1024 + // This should be handled by the model implementation + return null; + } + + public CohereEmbeddingOptions copy() { + return fromOptions(this); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.inputType, this.embeddingTypes, this.truncate); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + CohereEmbeddingOptions that = (CohereEmbeddingOptions) o; + + return Objects.equals(this.model, that.model) && Objects.equals(this.inputType, that.inputType) + && Objects.equals(this.embeddingTypes, that.embeddingTypes) + && Objects.equals(this.truncate, that.truncate); + } + + public static final class Builder { + + private CohereEmbeddingOptions options; + + public Builder() { + this.options = new CohereEmbeddingOptions(); + } + + public Builder(CohereEmbeddingOptions options) { + this.options = options; + } + + public Builder model(String model) { + this.options.model = model; + return this; + } + + public Builder model(CohereApi.EmbeddingModel model) { + this.options.model = model.getValue(); + return this; + } + + public Builder inputType(InputType inputType) { + this.options.inputType = inputType; + return this; + } + + public Builder embeddingTypes(List embeddingTypes) { + this.options.embeddingTypes = embeddingTypes; + return this; + } + + public Builder truncate(Truncate truncate) { + this.options.truncate = truncate; + return this; + } + + public CohereEmbeddingOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java new file mode 100644 index 00000000000..997fb4bd89e --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.Base64; +import java.util.List; + +import org.springframework.ai.content.Media; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +/** + * Utility class for Cohere embedding operations. + * + * @author Ricken Bazolo + */ +public final class CohereEmbeddingUtils { + + private static final List SUPPORTED_IMAGE_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, + MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/webp"), MimeTypeUtils.IMAGE_GIF); + + private static final long MAX_IMAGE_SIZE_BYTES = 5 * 1024 * 1024; + + private CohereEmbeddingUtils() { + } + + public static String mediaToDataUri(Media media) { + Assert.notNull(media, "Media cannot be null"); + validateImageMedia(media); + + byte[] imageData = getImageBytes(media); + validateImageSize(imageData); + + String base64Data = Base64.getEncoder().encodeToString(imageData); + String mimeType = media.getMimeType().toString(); + + return String.format("data:%s;base64,%s", mimeType, base64Data); + } + + private static void validateImageMedia(Media media) { + MimeType mimeType = media.getMimeType(); + boolean isSupported = SUPPORTED_IMAGE_TYPES.stream() + .anyMatch(supported -> mimeType.isCompatibleWith(supported)); + + if (!isSupported) { + throw new IllegalArgumentException("Unsupported image MIME type: " + mimeType + + ". Supported types: image/jpeg, image/png, image/webp, image/gif"); + } + } + + private static byte[] getImageBytes(Media media) { + Object data = media.getData(); + + if (data instanceof byte[] bytes) { + return bytes; + } + else if (data instanceof String base64String) { + return Base64.getDecoder().decode(base64String); + } + else { + throw new IllegalArgumentException("Media data must be byte[] or base64 String"); + } + } + + private static void validateImageSize(byte[] imageData) { + if (imageData.length > MAX_IMAGE_SIZE_BYTES) { + throw new IllegalArgumentException( + String.format("Image size (%d bytes) exceeds maximum allowed size (%d bytes)", imageData.length, + MAX_IMAGE_SIZE_BYTES)); + } + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java new file mode 100644 index 00000000000..15ea8b5b17e --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java @@ -0,0 +1,208 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingModel; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Implementation of the Cohere Multimodal Embedding Model. + * + * @author Ricken Bazolo + */ +public class CohereMultimodalEmbeddingModel implements DocumentEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(CohereMultimodalEmbeddingModel.class); + + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of( + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue(), 1024, + CohereApi.EmbeddingModel.EMBED_ENGLISH_V3.getValue(), 1024, + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue(), 384, + CohereApi.EmbeddingModel.EMBED_ENGLISH_LIGHT_V3.getValue(), 384, + CohereApi.EmbeddingModel.EMBED_V4.getValue(), 1536); + + private final CohereMultimodalEmbeddingOptions defaultOptions; + + private final CohereApi cohereApi; + + private final RetryTemplate retryTemplate; + + private final ObservationRegistry observationRegistry; + + public CohereMultimodalEmbeddingModel(CohereApi cohereApi, CohereMultimodalEmbeddingOptions options, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + Assert.notNull(cohereApi, "cohereApi must not be null"); + Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + + this.cohereApi = cohereApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public EmbeddingResponse call(DocumentEmbeddingRequest request) { + CohereMultimodalEmbeddingOptions mergedOptions = mergeOptions(request.getOptions(), this.defaultOptions); + + List allEmbeddings = new ArrayList<>(); + EmbeddingResponseMetadata lastMetadata = null; + + for (Document document : request.getInstructions()) { + CohereApi.EmbeddingRequest apiRequest; + + if (document.getMedia() != null) { + apiRequest = createImageRequest(document, mergedOptions); + } + else if (StringUtils.hasText(document.getText())) { + apiRequest = createTextRequest(document, mergedOptions); + } + else { + logger.warn("Document {} has no text or media content", document.getId()); + continue; + } + + var apiResponse = RetryUtils.execute(this.retryTemplate, + () -> this.cohereApi.embeddings(apiRequest).getBody()); + + if (apiResponse != null) { + List floatEmbeddings = apiResponse.getFloatEmbeddings(); + for (int i = 0; i < floatEmbeddings.size(); i++) { + allEmbeddings.add(new Embedding(floatEmbeddings.get(i), allEmbeddings.size())); + } + lastMetadata = generateResponseMetadata(apiResponse.responseType()); + } + } + + return new EmbeddingResponse(allEmbeddings, lastMetadata); + } + + @Override + public int dimensions() { + String model = this.defaultOptions.getModel(); + if (model == null) { + return KNOWN_EMBEDDING_DIMENSIONS.get(CohereApi.EmbeddingModel.EMBED_V4.getValue()); + } + return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(model, 1024); + } + + private CohereApi.EmbeddingRequest createTextRequest(Document document, + CohereMultimodalEmbeddingOptions options) { + return CohereApi.EmbeddingRequest.builder() + .model(options.getModel()) + .inputType(CohereApi.EmbeddingRequest.InputType.CLASSIFICATION) + .embeddingTypes(options.getEmbeddingTypes()) + .texts(List.of(document.getText())) + .truncate(options.getTruncate()) + .build(); + } + + private CohereApi.EmbeddingRequest createImageRequest(Document document, + CohereMultimodalEmbeddingOptions options) { + + String dataUri = CohereEmbeddingUtils.mediaToDataUri(document.getMedia()); + + return CohereApi.EmbeddingRequest.builder() + .model(options.getModel()) + .inputType(CohereApi.EmbeddingRequest.InputType.IMAGE) + .embeddingTypes(options.getEmbeddingTypes()) + .images(List.of(dataUri)) + .truncate(options.getTruncate()) + .build(); + } + + private CohereMultimodalEmbeddingOptions mergeOptions( + org.springframework.ai.embedding.EmbeddingOptions requestOptions, + CohereMultimodalEmbeddingOptions defaultOptions) { + CohereMultimodalEmbeddingOptions options = (requestOptions != null) + ? ModelOptionsUtils.merge(requestOptions, defaultOptions, CohereMultimodalEmbeddingOptions.class) + : defaultOptions; + + if (options == null) { + throw new IllegalArgumentException("Embedding options must not be null"); + } + + return options; + } + + private EmbeddingResponseMetadata generateResponseMetadata(String embeddingType) { + return new EmbeddingResponseMetadata(embeddingType, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private CohereApi cohereApi; + + private CohereMultimodalEmbeddingOptions options = CohereMultimodalEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_V4.getValue()) + .build(); + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + public Builder cohereApi(CohereApi cohereApi) { + this.cohereApi = cohereApi; + return this; + } + + public Builder options(CohereMultimodalEmbeddingOptions options) { + this.options = options; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public CohereMultimodalEmbeddingModel build() { + return new CohereMultimodalEmbeddingModel(this.cohereApi, this.options, this.retryTemplate, + this.observationRegistry); + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java new file mode 100644 index 00000000000..40ad1d3ed41 --- /dev/null +++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java @@ -0,0 +1,177 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.InputType; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.Truncate; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingType; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * Options for the Cohere Multimodal Embedding API. + * + * @author Ricken Bazolo + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class CohereMultimodalEmbeddingOptions implements EmbeddingOptions { + + @JsonProperty("model") + private String model; + + @JsonProperty("input_type") + private InputType inputType; + + @JsonProperty("embedding_types") + private List embeddingTypes = new ArrayList<>(); + + @JsonProperty("truncate") + private Truncate truncate; + + public static Builder builder() { + return new Builder(); + } + + public static CohereMultimodalEmbeddingOptions fromOptions(CohereMultimodalEmbeddingOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .inputType(fromOptions.getInputType()) + .embeddingTypes( + fromOptions.getEmbeddingTypes() != null ? new ArrayList<>(fromOptions.getEmbeddingTypes()) : null) + .truncate(fromOptions.getTruncate()) + .build(); + } + + private CohereMultimodalEmbeddingOptions() { + this.embeddingTypes.add(EmbeddingType.FLOAT); + this.inputType = InputType.CLASSIFICATION; + this.model = CohereApi.EmbeddingModel.EMBED_V4.getValue(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public InputType getInputType() { + return this.inputType; + } + + public void setInputType(InputType inputType) { + this.inputType = inputType; + } + + public List getEmbeddingTypes() { + return this.embeddingTypes; + } + + public void setEmbeddingTypes(List embeddingTypes) { + this.embeddingTypes = embeddingTypes; + } + + public Truncate getTruncate() { + return this.truncate; + } + + public void setTruncate(Truncate truncate) { + this.truncate = truncate; + } + + @Override + public Integer getDimensions() { + return null; + } + + public CohereMultimodalEmbeddingOptions copy() { + return fromOptions(this); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.inputType, this.embeddingTypes, this.truncate); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + CohereMultimodalEmbeddingOptions that = (CohereMultimodalEmbeddingOptions) o; + + return Objects.equals(this.model, that.model) && Objects.equals(this.inputType, that.inputType) + && Objects.equals(this.embeddingTypes, that.embeddingTypes) + && Objects.equals(this.truncate, that.truncate); + } + + public static final class Builder { + + private CohereMultimodalEmbeddingOptions options; + + public Builder() { + this.options = new CohereMultimodalEmbeddingOptions(); + } + + public Builder(CohereMultimodalEmbeddingOptions options) { + this.options = options; + } + + public Builder model(String model) { + this.options.model = model; + return this; + } + + public Builder model(CohereApi.EmbeddingModel model) { + this.options.model = model.getValue(); + return this; + } + + public Builder inputType(InputType inputType) { + this.options.inputType = inputType; + return this; + } + + public Builder embeddingTypes(List embeddingTypes) { + this.options.embeddingTypes = embeddingTypes; + return this; + } + + public Builder truncate(Truncate truncate) { + this.options.truncate = truncate; + return this; + } + + public CohereMultimodalEmbeddingOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..e6ba5bc93af --- /dev/null +++ b/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.cohere.aot.CohereRuntimeHints diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java new file mode 100644 index 00000000000..89d4cdb2b84 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java @@ -0,0 +1,181 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere; + +import java.util.List; +import java.util.Optional; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletion; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionFinishReason; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest; +import org.springframework.ai.cohere.api.CohereApi.EmbeddingResponse; +import org.springframework.ai.cohere.api.CohereApi.Usage; +import org.springframework.ai.cohere.chat.CohereChatModel; +import org.springframework.ai.cohere.chat.CohereChatOptions; +import org.springframework.ai.cohere.embedding.CohereEmbeddingModel; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; + +/** + * @author Ricken Bazolo + */ +@SuppressWarnings("unchecked") +@ExtendWith(MockitoExtension.class) +public class CohereRetryTests { + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + private @Mock CohereApi cohereApi; + + private CohereChatModel chatModel; + + private CohereEmbeddingModel embeddingModel; + + @BeforeEach + public void beforeEach() { + this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.setRetryListener(this.retryListener); + + this.chatModel = CohereChatModel.builder() + .cohereApi(this.cohereApi) + .defaultOptions(CohereChatOptions.builder() + .temperature(0.7) + .topP(1.0) + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .build()) + .retryTemplate(this.retryTemplate) + .build(); + this.embeddingModel = CohereEmbeddingModel.builder() + .cohereApi(this.cohereApi) + .retryTemplate(this.retryTemplate) + .build(); + } + + @Test + public void cohereChatTransientError() { + var message = new ChatCompletionMessage.Provider( + List.of(new ChatCompletionMessage.MessageContent("text", "Response", null)), Role.ASSISTANT, null, null, + null); + + ChatCompletion expectedChatCompletion = new ChatCompletion("id", ChatCompletionFinishReason.COMPLETE, message, + null, new Usage(null, new Usage.Tokens(10, 20), 10)); + + given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = this.chatModel.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEqualTo("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); + } + + @Test + public void cohereChatNonTransientError() { + given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); + } + + @Test + public void cohereEmbeddingTransientError() { + List> embeddingsList = List.of(List.of(9.9, 8.8), List.of(7.7, 6.6)); + + EmbeddingResponse expectedEmbeddings = new EmbeddingResponse("id", embeddingsList, List.of("text1", "text2"), + "embeddings_floats"); + + given(this.cohereApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = this.embeddingModel + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); + } + + @Test + public void cohereEmbeddingNonTransientError() { + given(this.cohereApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + + @Test + public void cohereChatMixedTransientAndNonTransientErrors() { + given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error")) + .willThrow(new RuntimeException("Non Transient Error")); + + // Should fail immediately on non-transient error, no further retries + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); + + // Should have 1 retry attempt before hitting non-transient error + assertThat(this.retryListener.retryCount).isEqualTo(1); + } + + private static class TestRetryListener implements RetryListener { + + int retryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; + } + + @Override + public void beforeRetry(RetryPolicy retryPolicy, Retryable retryable) { + this.retryCount++; + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java new file mode 100644 index 00000000000..4c03a23efb9 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.chat.CohereChatModel; +import org.springframework.ai.cohere.chat.CohereChatOptions; +import org.springframework.ai.cohere.embedding.CohereEmbeddingModel; +import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Ricken Bazolo + */ +@SpringBootConfiguration +public class CohereTestConfiguration { + + private static String retrieveApiKey() { + var apiKey = System.getenv("COHERE_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "Missing COHERE_API_KEY environment variable. Please set it to your Cohere API key."); + } + return apiKey; + } + + @Bean + public CohereApi cohereApi() { + return CohereApi.builder().apiKey(retrieveApiKey()).build(); + } + + @Bean + public CohereChatModel cohereChatModel(CohereApi api) { + return CohereChatModel.builder() + .cohereApi(api) + .defaultOptions(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A.getValue()).build()) + .build(); + } + + @Bean + public CohereEmbeddingModel cohereEmbeddingModel(CohereApi api) { + return CohereEmbeddingModel.builder().cohereApi(api).build(); + } + + @Bean + public CohereMultimodalEmbeddingModel cohereMultimodalEmbeddingModel(CohereApi api) { + return CohereMultimodalEmbeddingModel.builder().cohereApi(api).build(); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java new file mode 100644 index 00000000000..065edf04eeb --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java @@ -0,0 +1,245 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.aot; + +import java.util.HashSet; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.chat.CohereChatOptions; +import org.springframework.ai.cohere.embedding.CohereEmbeddingOptions; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +class CohereRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); + } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletion.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletionChunk.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.LogProbs.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletionFinishReason.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereChatOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereEmbeddingOptions.class))).isTrue(); + } + + @Test + void registerHintsWithNullClassLoader() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + + // Should not throw exception with null classLoader + cohereRuntimeHints.registerHints(runtimeHints, null); + + // Verify hints were registered + assertThat(runtimeHints.reflection().typeHints().count()).isGreaterThan(0); + } + + @Test + void registerHintsWithValidClassLoader() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + + cohereRuntimeHints.registerHints(runtimeHints, classLoader); + + // Verify hints were registered + assertThat(runtimeHints.reflection().typeHints().count()).isGreaterThan(0); + } + + @Test + void registerHintsIsIdempotent() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + + // Register hints twice + cohereRuntimeHints.registerHints(runtimeHints, null); + long firstCount = runtimeHints.reflection().typeHints().count(); + + cohereRuntimeHints.registerHints(runtimeHints, null); + long secondCount = runtimeHints.reflection().typeHints().count(); + + // Should have same number of hints + assertThat(firstCount).isEqualTo(secondCount); + } + + @Test + void verifyExpectedTypesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify some expected types are registered (adjust class names as needed) + assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("Cohere"))).isTrue(); + assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("ChatCompletion"))).isTrue(); + } + + @Test + void verifyPackageScanningWorks() { + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere"); + + // Verify package scanning found classes + assertThat(jsonAnnotatedClasses.size()).isGreaterThan(0); + } + + @Test + void verifyAllCriticalApiClassesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Ensure critical API classes are registered for GraalVM native image reflection + String[] criticalClasses = { "CohereApi$ChatCompletionRequest", "CohereApi$ChatCompletionMessage", + "CohereApi$EmbeddingRequest", "CohereApi$EmbeddingModel", "CohereApi$Usage" }; + + for (String className : criticalClasses) { + assertThat(registeredTypes.stream() + .anyMatch(tr -> tr.getName().contains(className.replace("$", ".")) + || tr.getName().contains(className.replace("$", "$")))) + .as("Critical class %s should be registered", className) + .isTrue(); + } + } + + @Test + void verifyEnumTypesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Enums are critical for JSON deserialization in native images + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatModel.class))) + .as("ChatModel enum should be registered") + .isTrue(); + + assertThat(registeredTypes.contains(TypeReference.of(CohereApi.EmbeddingModel.class))) + .as("EmbeddingModel enum should be registered") + .isTrue(); + } + + @Test + void verifyReflectionHintsIncludeConstructors() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + // Verify that reflection hints include constructor access + boolean hasConstructorHints = runtimeHints.reflection() + .typeHints() + .anyMatch(typeHint -> typeHint.constructors().findAny().isPresent() || typeHint.getMemberCategories() + .contains(org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)); + + assertThat(hasConstructorHints).as("Should register constructor hints for JSON deserialization").isTrue(); + } + + @Test + void verifyNoExceptionThrownWithEmptyRuntimeHints() { + RuntimeHints emptyRuntimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + + // Should not throw any exception even with empty runtime hints + assertThatCode(() -> cohereRuntimeHints.registerHints(emptyRuntimeHints, null)).doesNotThrowAnyException(); + + assertThat(emptyRuntimeHints.reflection().typeHints().count()).isGreaterThan(0); + } + + @Test + void verifyProxyHintsAreNotRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + // MistralAi should only register reflection hints, not proxy hints + assertThat(runtimeHints.proxies().jdkProxyHints().count()).isEqualTo(0); + } + + @Test + void verifySerializationHintsAreNotRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + // MistralAi should only register reflection hints, not serialization hints + assertThat(runtimeHints.serialization().javaSerializationHints().count()).isEqualTo(0); + } + + @Test + void verifyResponseTypesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints(); + cohereRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify response wrapper types are registered + assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("EmbeddingResponse"))) + .as("EmbeddingResponse type should be registered") + .isTrue(); + + assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("ChatCompletion"))) + .as("ChatCompletion response type should be registered") + .isTrue(); + } + + @Test + void verifyMultipleInstancesRegisterSameHints() { + RuntimeHints runtimeHints1 = new RuntimeHints(); + RuntimeHints runtimeHints2 = new RuntimeHints(); + + CohereRuntimeHints hints1 = new CohereRuntimeHints(); + CohereRuntimeHints hints2 = new CohereRuntimeHints(); + + hints1.registerHints(runtimeHints1, null); + hints2.registerHints(runtimeHints2, null); + + long count1 = runtimeHints1.reflection().typeHints().count(); + long count2 = runtimeHints2.reflection().typeHints().count(); + + assertThat(count1).isEqualTo(count2); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java new file mode 100644 index 00000000000..656f268ad36 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api; + +import java.util.List; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.cohere.CohereTestConfiguration; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletion; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest; +import org.springframework.ai.cohere.testutils.AbstractIT; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +class CohereApiIT extends AbstractIT { + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + ResponseEntity response = this.cohereApi.chatCompletionEntity(new ChatCompletionRequest( + List.of(chatCompletionMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionEntityWithSystemMessage() { + ChatCompletionMessage userMessage = new ChatCompletionMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did?", Role.USER); + ChatCompletionMessage systemMessage = new ChatCompletionMessage(""" + You are an AI assistant that helps people find information. + Your name is Bob. + You should reply to the user's request with your name and also in the style of a pirate. + """, Role.SYSTEM); + + ResponseEntity response = this.cohereApi.chatCompletionEntity(new ChatCompletionRequest( + List.of(systemMessage, userMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void embeddings() { + ResponseEntity response = this.cohereApi + .embeddings(CohereApi.EmbeddingRequest.builder().texts("Hello world").build()); + + assertThat(response).isNotNull(); + Assertions.assertNotNull(response.getBody()); + assertThat(response.getBody().getFloatEmbeddings()).hasSize(1); + assertThat(response.getBody().getFloatEmbeddings().get(0)).hasSize(1536); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + Flux response = this.cohereApi.chatCompletionStream(new ChatCompletionRequest( + List.of(chatCompletionMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java new file mode 100644 index 00000000000..250f78d0bca --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api.tool; + +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletion; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.cohere.api.CohereApi.FunctionTool.Type; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.util.ObjectUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +public class CohereApiToolFunctionCallIT { + + static final String MISTRAL_AI_CHAT_MODEL = CohereApi.ChatModel.COMMAND_A_R7B.getValue(); + + private final Logger logger = LoggerFactory.getLogger(CohereApiToolFunctionCallIT.class); + + MockWeatherService weatherService = new MockWeatherService(); + + CohereApi completionApi = CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build(); + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Test + @SuppressWarnings("null") + public void toolFunctionCall() throws JsonProcessingException { + + // Step 1: send the conversation and available functions to the model + var message = new ChatCompletionMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Show the temperature in Celsius.", + Role.USER); + + var functionTool = new CohereApi.FunctionTool(Type.FUNCTION, + new CohereApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", + ModelOptionsUtils.jsonToMap(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "unit"] + } + """))); + + List messages = new ArrayList<>(List.of(message)); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL, + List.of(functionTool), ToolChoice.REQUIRED); + + System.out + .println(new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatCompletionRequest)); + + ResponseEntity response = this.completionApi.chatCompletionEntity(chatCompletionRequest); + + ChatCompletion chatCompletion = response.getBody(); + + assertThat(chatCompletion).isNotNull(); + assertThat(chatCompletion.message()).isNotNull(); + + ChatCompletionMessage responseMessage = new ChatCompletionMessage(chatCompletion.message().content(), + chatCompletion.message().role(), chatCompletion.message().toolPlan(), + chatCompletion.message().toolCalls(), chatCompletion.message().citations(), null); + + assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT); + assertThat(responseMessage.toolCalls()).isNotNull(); + + // Check if the model wanted to call a function + if (!ObjectUtils.isEmpty(responseMessage.toolCalls())) { + + // extend conversation with assistant's reply. + messages.add(responseMessage); + + // Send the info for each function call and function response to the model. + for (ToolCall toolCall : responseMessage.toolCalls()) { + var functionName = toolCall.function().name(); + if ("getCurrentWeather".equals(functionName)) { + MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), + MockWeatherService.Request.class); + + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); + + // extend conversation with function response. + messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), + Role.TOOL, functionName, null, responseMessage.citations(), toolCall.id())); + } + } + + var functionResponseRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL, 0.8); + + ResponseEntity result2 = this.completionApi.chatCompletionEntity(functionResponseRequest); + + chatCompletion = result2.getBody(); + + logger.info("Final response: {}", chatCompletion); + + assertThat(chatCompletion.message().content()).isNotEmpty(); + + var messageContent = chatCompletion.message().content().get(0); + + assertThat(chatCompletion.message().role()).isEqualTo(Role.ASSISTANT); + assertThat(messageContent.text()).contains("San Francisco").containsAnyOf("30.0", "30"); + assertThat(messageContent.text()).contains("Tokyo").containsAnyOf("10.0", "10"); + assertThat(messageContent.text()).contains("Paris").containsAnyOf("15.0", "15"); + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..c0b96608a7d --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +public class MockWeatherService implements Function { + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java new file mode 100644 index 00000000000..2422154b458 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java @@ -0,0 +1,171 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.api.tool; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletion; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.cohere.api.CohereApi.FunctionTool; +import org.springframework.ai.cohere.api.CohereApi.FunctionTool.Type; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +public class PaymentStatusFunctionCallingIT { + + // Assuming we have the following data + public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", + new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", + new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); + + static Map> functions = Map.of("retrieve_payment_status", + new RetrievePaymentStatus(), "retrieve_payment_date", new RetrievePaymentDate()); + + private final Logger logger = LoggerFactory.getLogger(PaymentStatusFunctionCallingIT.class); + + private static T jsonToObject(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Test + @SuppressWarnings("null") + public void toolFunctionCall() throws JsonProcessingException { + + var transactionJsonSchema = """ + { + "type": "object", + "properties": { + "transaction_id": { + "type": "string", + "description": "The transaction id" + } + }, + "required": ["transaction_id"] + } + """; + + var paymentStatusTool = new FunctionTool(Type.FUNCTION, new FunctionTool.Function( + "Get payment status of a transaction", "retrieve_payment_status", transactionJsonSchema)); + + var paymentDateTool = new FunctionTool(Type.FUNCTION, new FunctionTool.Function( + "Get payment date of a transaction", "retrieve_payment_date", transactionJsonSchema)); + + List messages = new ArrayList<>( + List.of(new ChatCompletionMessage("What's the status of my transaction with id T1001?", Role.USER))); + + CohereApi cohereApi = CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build(); + + ResponseEntity response = cohereApi + .chatCompletionEntity(new ChatCompletionRequest(messages, CohereApi.ChatModel.COMMAND_A_R7B.getValue(), + List.of(paymentStatusTool, paymentDateTool), ToolChoice.REQUIRED)); + + ChatCompletion chatCompletion = response.getBody(); + + ChatCompletionMessage responseMessage = new ChatCompletionMessage(chatCompletion.message().content(), + chatCompletion.message().role(), chatCompletion.message().toolPlan(), + chatCompletion.message().toolCalls(), chatCompletion.message().citations(), null); + + assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT); + assertThat(responseMessage.toolCalls()).isNotNull(); + + // extend conversation with assistant's reply. + messages.add(responseMessage); + + // Send the info for each function call and function response to the model. + for (ToolCall toolCall : responseMessage.toolCalls()) { + + var functionName = toolCall.function().name(); + // Map the function, JSON arguments into a Transaction object. + Transaction transaction = jsonToObject(toolCall.function().arguments(), Transaction.class); + // Call the target function with the transaction object. + var result = functions.get(functionName).apply(transaction); + + // Extend conversation with function response. + // The functionName is used to identify the function response! + messages.add(new ChatCompletionMessage(result.toString(), Role.TOOL, functionName, null, + responseMessage.citations(), toolCall.id())); + } + + response = cohereApi + .chatCompletionEntity(new ChatCompletionRequest(messages, CohereApi.ChatModel.COMMAND_A_R7B.getValue())); + + chatCompletion = response.getBody(); + var content = chatCompletion.message().content().get(0).text(); + logger.info("Final response: {}", content); + + assertThat(content).containsIgnoringCase("T1001"); + assertThat(content).containsIgnoringCase("Paid"); + } + + record StatusDate(String status, String date) { + + } + + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + + private static class RetrievePaymentStatus implements Function { + + @Override + public Status apply(Transaction paymentTransaction) { + return new Status(DATA.get(paymentTransaction.transactionId).status); + } + + } + + private static class RetrievePaymentDate implements Function { + + @Override + public Date apply(Transaction paymentTransaction) { + return new Date(DATA.get(paymentTransaction.transactionId).date); + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java new file mode 100644 index 00000000000..7eedb9194cf --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java @@ -0,0 +1,290 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.cohere.CohereTestConfiguration; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.cohere.api.tool.MockWeatherService; +import org.springframework.ai.cohere.testutils.AbstractIT; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.test.CurlyBracketEscaper; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = CohereTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +class CohereChatClientIT extends AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(CohereChatClientIT.class); + + @Value("classpath:/prompts/system-message.st") + private Resource systemTextResource; + + @Test + void call() { + // @formatter:off + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) + .param("name", "Bob") + .param("voice", "pirate")) + .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .call() + .chatResponse(); + // @formatter:on + + logger.info("{}", response); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard"); + } + + @Test + void testMessageHistory() { + + // @formatter:off + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) + .param("name", "Bob") + .param("voice", "pirate")) + .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .call() + .chatResponse(); + // @formatter:on + assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard"); + + // @formatter:off + response = ChatClient.create(this.chatModel).prompt() + .messages(List.of(new UserMessage("Dummy"), response.getResult().getOutput())) + .user("Repeat the last assistant message.") + .call() + .chatResponse(); + // @formatter:on + + logger.info("" + response); + assertThat(response.getResult().getOutput().getText().toLowerCase()).containsAnyOf("blackbeard", + "bartholomew roberts"); + } + + @Test + void listOutputConverterString() { + // @formatter:off + List collection = ChatClient.create(this.chatModel).prompt() + .user(u -> u.text("List five {subject}") + .param("subject", "ice cream flavors")) + .call() + .entity(new ParameterizedTypeReference<>() { }); + // @formatter:on + + logger.info(collection.toString()); + assertThat(collection).hasSize(5); + } + + @Test + void listOutputConverterBean() { + + // @formatter:off + List actorsFilms = ChatClient.create(this.chatModel).prompt() + .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") + .call() + .entity(new ParameterizedTypeReference<>() { + }); + // @formatter:on + + logger.info("" + actorsFilms); + assertThat(actorsFilms).hasSize(2); + } + + @Test + void customOutputConverter() { + + var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); + + // @formatter:off + List flavors = ChatClient.create(this.chatModel).prompt() + .user(u -> u.text("List 10 {subject}") + .param("subject", "ice cream flavors")) + .call() + .entity(toStringListConverter); + // @formatter:on + + logger.info("ice cream flavors{}", flavors); + assertThat(flavors).hasSize(10); + assertThat(flavors).containsAnyOf("Vanilla", "vanilla"); + } + + @Test + void mapOutputConverter() { + // @formatter:off + Map result = ChatClient.create(this.chatModel).prompt() + .user(u -> u.text("Provide me a List of {subject}") + .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) + .call() + .entity(new ParameterizedTypeReference<>() { + }); + // @formatter:on + + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + } + + @Test + void beanOutputConverter() { + + // @formatter:off + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() + .user("Generate the filmography for a random actor.") + .call() + .entity(ActorsFilms.class); + // @formatter:on + + logger.info("{}", actorsFilms); + assertThat(actorsFilms.actor()).isNotBlank(); + } + + @Test + void beanOutputConverterRecords() { + + // @formatter:off + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() + .user("Generate the filmography of 5 movies for Tom Hanks.") + .call() + .entity(ActorsFilms.class); + // @formatter:on + + logger.info("{}", actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + // @formatter:off + Flux chatResponse = ChatClient.create(this.chatModel) + .prompt() + .advisors(new SimpleLoggerAdvisor()) + .user(u -> u + .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() + + "{format}") + .param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat()))) + .stream() + .content(); + + String generationTextFromStream = chatResponse.collectList() + .block() + .stream() + .collect(Collectors.joining()); + // @formatter:on + + ActorsFilms actorsFilms = outputConverter.convert(generationTextFromStream); + + logger.info("{}", actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .options(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A_R7B).toolChoice(ToolChoice.REQUIRED).build()) + .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) + .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(response).containsAnyOf("30.0", "30"); + assertThat(response).containsAnyOf("10.0", "10"); + assertThat(response).containsAnyOf("15.0", "15"); + } + + @Test + void streamFunctionCallTest() { + + // @formatter:off + Flux response = ChatClient.create(this.chatModel).prompt() + .options(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A_R7B).build()) + .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") + .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .stream() + .content(); + // @formatter:on + + String content = response.collectList().block().stream().collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + } + + @Test + void validateCallResponseMetadata() { + String model = CohereApi.ChatModel.COMMAND_A_R7B.getName(); + // @formatter:off + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .options(CohereChatOptions.builder().model(model).build()) + .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .call() + .chatResponse(); + // @formatter:on + + logger.info(response.toString()); + assertThat(response.getMetadata().getId()).isNotEmpty(); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + } + + record ActorsFilms(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java new file mode 100644 index 00000000000..80dda662b01 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java @@ -0,0 +1,316 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AbstractMessage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Ricken Bazolo + */ +class CohereChatCompletionRequestTests { + + private static final String BASE_URL = "https://faked.url"; + + private static final String API_KEY = "FAKED_API_KEY"; + + private static final String TEXT_CONTENT = "Hello world!"; + + private static final String IMAGE_URL = "https://example.com/image.png"; + + private static final Media IMAGE_MEDIA = new Media(Media.Format.IMAGE_PNG, URI.create(IMAGE_URL)); + + private final CohereChatModel chatModel = CohereChatModel.builder() + .cohereApi(CohereApi.builder().baseUrl(BASE_URL).apiKey(API_KEY).build()) + .build(); + + @Test + void chatCompletionDefaultRequestTest() { + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content")); + var request = this.chatModel.createRequest(prompt, false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.temperature()).isEqualTo(0.3); + assertThat(request.p()).isEqualTo(1); + assertThat(request.maxTokens()).isNull(); + assertThat(request.stream()).isFalse(); + } + + @Test + void chatCompletionRequestWithOptionsTest() { + var options = CohereChatOptions.builder().temperature(0.5).topP(0.8).build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content", options)); + var request = this.chatModel.createRequest(prompt, true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.p()).isEqualTo(0.8); + assertThat(request.temperature()).isEqualTo(0.5); + assertThat(request.stream()).isTrue(); + } + + @Test + void whenToolRuntimeOptionsThenMergeWithDefaults() { + CohereChatOptions defaultOptions = CohereChatOptions.builder() + .model("DEFAULT_MODEL") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1", "key2", "valueA")) + .build(); + + CohereChatModel anotherChatModel = CohereChatModel.builder() + .cohereApi(CohereApi.builder().baseUrl(BASE_URL).apiKey(API_KEY).build()) + .defaultOptions(defaultOptions) + .build(); + + CohereChatOptions runtimeOptions = CohereChatOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "valueB")) + .build(); + Prompt prompt = anotherChatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() + .stream() + .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "valueB"); + } + + @Test + void createChatCompletionMessagesWithUserMessage() { + var userMessage = new UserMessage(TEXT_CONTENT); + userMessage.getMedia().add(IMAGE_MEDIA); + var prompt = createPrompt(userMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifyUserChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithSimpleUserMessage() { + var simpleUserMessage = new SimpleMessage(MessageType.USER, TEXT_CONTENT); + var prompt = createPrompt(simpleUserMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + } + + @Test + void createChatCompletionMessagesWithSystemMessage() { + var systemMessage = new SystemMessage(TEXT_CONTENT); + var prompt = createPrompt(systemMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifySystemChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithSimpleSystemMessage() { + var simpleSystemMessage = new SimpleMessage(MessageType.SYSTEM, TEXT_CONTENT); + var prompt = createPrompt(simpleSystemMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifySystemChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithAssistantMessage() { + var toolCall1 = createToolCall(1); + var toolCall2 = createToolCall(2); + var toolCall3 = createToolCall(3); + // @formatter:off + var assistantMessage = AssistantMessage.builder() + .content(TEXT_CONTENT) + .toolCalls(List.of(toolCall1, toolCall2, toolCall3)) + .build(); + // @formatter:on + var prompt = createPrompt(assistantMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.ASSISTANT); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + var toolCalls = chatCompletionMessage.toolCalls(); + assertThat(toolCalls).hasSize(3); + verifyToolCall(toolCalls.get(0), toolCall1); + verifyToolCall(toolCalls.get(1), toolCall2); + verifyToolCall(toolCalls.get(2), toolCall3); + } + + @Test + void createChatCompletionMessagesWithSimpleAssistantMessage() { + var simpleAssistantMessage = new SimpleMessage(MessageType.ASSISTANT, TEXT_CONTENT); + var prompt = createPrompt(simpleAssistantMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unsupported assistant message class: " + SimpleMessage.class.getName()); + } + + @Test + void createChatCompletionMessagesWithToolResponseMessage() { + var toolResponse1 = createToolResponse(1); + var toolResponse2 = createToolResponse(2); + var toolResponse3 = createToolResponse(3); + var toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(toolResponse1, toolResponse2, toolResponse3)) + .build(); + var prompt = createPrompt(toolResponseMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(3); + verifyToolChatCompletionMessage(chatCompletionMessages.get(0), toolResponse1); + verifyToolChatCompletionMessage(chatCompletionMessages.get(1), toolResponse2); + verifyToolChatCompletionMessage(chatCompletionMessages.get(2), toolResponse3); + } + + @Test + void createChatCompletionMessagesWithInvalidToolResponseMessage() { + var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null); + var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build(); + var prompt = createPrompt(toolResponseMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("ToolResponseMessage.ToolResponse must have an id"); + } + + @Test + void createChatCompletionMessagesWithSimpleToolMessage() { + var simpleToolMessage = new SimpleMessage(MessageType.TOOL, TEXT_CONTENT); + var prompt = createPrompt(simpleToolMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unsupported tool message class: " + SimpleMessage.class.getName()); + } + + private Prompt createPrompt(Message message) { + var chatOptions = CohereChatOptions.builder().temperature(0.7d).build(); + var prompt = new Prompt(message, chatOptions); + + return this.chatModel.buildRequestPrompt(prompt); + } + + private static void verifyToolChatCompletionMessage(ChatCompletionMessage chatCompletionMessage, + ToolResponseMessage.ToolResponse toolResponse) { + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.TOOL); + assertThat(chatCompletionMessage.content()).isEqualTo(toolResponse.responseData()); + assertThat(chatCompletionMessage.toolCalls()).isNull(); + assertThat(chatCompletionMessage.toolCallId()).isEqualTo(toolResponse.id()); + } + + private static ToolResponseMessage.ToolResponse createToolResponse(int number) { + return new ToolResponseMessage.ToolResponse("id" + number, "name" + number, "responseData" + number); + } + + private static void verifyToolCall(ChatCompletionMessage.ToolCall mistralToolCall, + AssistantMessage.ToolCall toolCall) { + assertThat(mistralToolCall.id()).isEqualTo(toolCall.id()); + assertThat(mistralToolCall.type()).isEqualTo(toolCall.type()); + var function = mistralToolCall.function(); + assertThat(function).isNotNull(); + assertThat(function.name()).isEqualTo(toolCall.name()); + assertThat(function.arguments()).isEqualTo(toolCall.arguments()); + } + + private static AssistantMessage.ToolCall createToolCall(int number) { + return new AssistantMessage.ToolCall("id" + number, "type" + number, "name" + number, "arguments " + number); + } + + private static void verifySystemChatCompletionMessages(List chatCompletionMessages) { + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.SYSTEM); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + } + + private static void verifyUserChatCompletionMessages(List chatCompletionMessages) { + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER); + var rawContent = chatCompletionMessage.rawContent(); + assertThat(rawContent).isNotNull(); + var maps = (List>) rawContent; + assertThat(maps).hasSize(2); + // @formatter:off + var textMap = maps.get(0); + assertThat(textMap).hasSize(2) + .containsEntry("type", "text") + .containsEntry("text", TEXT_CONTENT); + var imageUrlMap = maps.get(1); + assertThat(imageUrlMap).hasSize(2) + .containsEntry("type", "image_url") + .containsEntry("image_url", Map.of("url", IMAGE_URL)); + // @formatter:on + } + + static class SimpleMessage extends AbstractMessage { + + SimpleMessage(MessageType messageType, String textContent) { + super(messageType, textContent, Map.of()); + } + + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + TestToolCallback(String name) { + this.toolDefinition = DefaultToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return this.toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelIT.java new file mode 100644 index 00000000000..81798284e5b --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelIT.java @@ -0,0 +1,368 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.cohere.CohereTestConfiguration; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.cohere.api.CohereApi.ChatModel; +import org.springframework.ai.cohere.api.tool.MockWeatherService; +import org.springframework.ai.cohere.testutils.AbstractIT; +import org.springframework.ai.content.Media; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.support.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.util.MimeTypeUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +class CohereChatModelIT extends AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(CohereChatModelIT.class); + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("subject", "ice cream flavors", "format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getText()); + assertThat(list).hasSize(5); + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", + format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getText()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText()); + logger.info("{}", actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = this.streamingChatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Response in Celsius"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).containsAnyOf("30.0", "30"); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata().getUsage()).isNotNull(); + } + + @Test + void streamFunctionCallTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in Tokyo, Japan? Response in Celsius"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("10.0", "10"); + } + + @Test + void streamFunctionCallUsageTest() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Response in Celsius"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); + ChatResponse chatResponse = response.last().block(); + + logger.info("Response: {}", chatResponse); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + } + + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).contains("James Bond"); + } + + @Test + void chatMemoryWithTools() { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(new MathTools())) + .internalToolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt( + List.of(new SystemMessage("You are a helpful assistant."), new UserMessage("What is 6 * 8?")), + chatOptions); + chatMemory.add(conversationId, prompt.getInstructions()); + + Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + ChatResponse chatResponse = this.chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + + while (chatResponse.hasToolCalls()) { + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(promptWithMemory, + chatResponse); + chatMemory.add(conversationId, toolExecutionResult.conversationHistory() + .get(toolExecutionResult.conversationHistory().size() - 1)); + promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + chatResponse = this.chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + } + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).contains("48"); + + UserMessage newUserMessage = new UserMessage("What did I ask you earlier?"); + chatMemory.add(conversationId, newUserMessage); + + ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId))); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8"); + } + + @Test + void streamingMultiModalityImageUrl() throws IOException { + + var userMessage = UserMessage.builder() + .text("Explain what do you see on this picture?") + .media(List.of(Media.builder() + .mimeType(MimeTypeUtils.IMAGE_PNG) + .data(URI.create("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")) + .build())) + .build(); + + Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), + CohereChatOptions.builder().model(ChatModel.COMMAND_A_VISION.getValue()).build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + + static class MathTools { + + @Tool(description = "Multiply the two numbers") + double multiply(double a, double b) { + return a * b; + } + + } + + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelObservationIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelObservationIT.java new file mode 100644 index 00000000000..8d5e72ee331 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelObservationIT.java @@ -0,0 +1,185 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.List; + +import io.micrometer.common.KeyValue; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; + +/** + * Integration tests for observation instrumentation in {@link CohereChatModel}. + * + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereChatModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +public class CohereChatModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + CohereChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void observationForChatOperation() { + var options = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .maxTokens(2048) + .stop(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + @Test + void observationForStreamingChatOperation() { + var options = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .maxTokens(2048) + .stop(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + Flux chatResponseFlux = this.chatModel.stream(prompt); + + List responses = chatResponseFlux.collectList().block(); + assertThat(responses).isNotEmpty(); + + // With MessageAggregator, all chunks are aggregated into a single response + // So we get the aggregated text from the last (or only) response + ChatResponse lastChatResponse = responses.get(responses.size() - 1); + String aggregatedResponse = lastChatResponse.getResult().getOutput().getText(); + assertThat(aggregatedResponse).isNotEmpty(); + + ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + private void validate(ChatResponseMetadata responseMetadata) { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("chat " + CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.COHERE.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), + CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), + StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel() + : KeyValue.NONE_VALUE) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), + "[\"this-is-the-end\"]") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_TOP_K.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .matches(contextView -> { + var keyValue = contextView.getHighCardinalityKeyValues() + .stream() + .filter(tag -> tag.getKey().equals(HighCardinalityKeyNames.RESPONSE_ID.asString())) + .findFirst(); + if (StringUtils.hasText(responseMetadata.getId())) { + return keyValue.isPresent() && keyValue.get().getValue().equals(responseMetadata.getId()); + } + else { + return keyValue.isEmpty(); + } + }) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"COMPLETE\"]") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public CohereApi cohereApi() { + return CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build(); + } + + @Bean + public CohereChatModel cohereChatModel(CohereApi cohereApi, TestObservationRegistry observationRegistry) { + return CohereChatModel.builder() + .cohereApi(cohereApi) + .defaultOptions(CohereChatOptions.builder().build()) + .retryTemplate(new RetryTemplate()) + .observationRegistry(observationRegistry) + .build(); + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatOptionsTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatOptionsTests.java new file mode 100644 index 00000000000..eb645d6a8ee --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatOptionsTests.java @@ -0,0 +1,249 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.chat; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.cohere.api.CohereApi; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link CohereChatOptions}. + * + * @author Ricken Bazolo + */ +class CohereChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + CohereChatOptions options = CohereChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .seed(123) + .stop(List.of("stop1", "stop2")) + .toolChoice(CohereApi.ChatCompletionRequest.ToolChoice.REQUIRED) + .internalToolExecutionEnabled(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "temperature", "topP", "maxTokens", "seed", "stop", "toolChoice", + "internalToolExecutionEnabled", "toolContext") + .containsExactly("test-model", 0.7, 0.9, 100, 123, List.of("stop1", "stop2"), + CohereApi.ChatCompletionRequest.ToolChoice.REQUIRED, true, Map.of("key1", "value1")); + } + + @Test + void testBuilderWithEnum() { + CohereChatOptions optionsWithEnum = CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue()) + .build(); + assertThat(optionsWithEnum.getModel()).isEqualTo(CohereApi.ChatModel.COMMAND_A_R7B.getValue()); + } + + @Test + void testCopy() { + CohereChatOptions options = CohereChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .seed(123) + .stop(List.of("stop1", "stop2")) + .internalToolExecutionEnabled(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + CohereChatOptions copiedOptions = options.copy(); + assertThat(copiedOptions).isNotSameAs(options).isEqualTo(options); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(options.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(options.getToolContext()); + } + + @Test + void testSetters() { + CohereChatOptions options = new CohereChatOptions(); + options.setModel("test-model"); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setMaxTokens(100); + options.setSeed(123); + options.setStopSequences(List.of("stop1", "stop2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getSeed()).isEqualTo(123); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + } + + @Test + void testDefaultValues() { + CohereChatOptions options = new CohereChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.getStopSequences()).isNull(); + } + + @Test + void testBuilderWithEmptyCollections() { + CohereChatOptions options = CohereChatOptions.builder() + .stop(Collections.emptyList()) + .toolContext(Collections.emptyMap()) + .build(); + + assertThat(options.getStop()).isEmpty(); + assertThat(options.getToolContext()).isEmpty(); + } + + @Test + void testBuilderWithBoundaryValues() { + CohereChatOptions options = CohereChatOptions.builder() + .temperature(0.0) + .topP(1.0) + .maxTokens(1) + .seed(Integer.MAX_VALUE) + .build(); + + assertThat(options.getTemperature()).isEqualTo(0.0); + assertThat(options.getTopP()).isEqualTo(1.0); + assertThat(options.getMaxTokens()).isEqualTo(1); + assertThat(options.getSeed()).isEqualTo(Integer.MAX_VALUE); + } + + @Test + void testBuilderWithSingleElementCollections() { + CohereChatOptions options = CohereChatOptions.builder() + .stop(List.of("single-stop")) + .toolContext(Map.of("single-key", "single-value")) + .build(); + + assertThat(options.getStop()).hasSize(1).containsExactly("single-stop"); + assertThat(options.getToolContext()).hasSize(1).containsEntry("single-key", "single-value"); + } + + @Test + void testCopyWithEmptyOptions() { + CohereChatOptions emptyOptions = new CohereChatOptions(); + CohereChatOptions copiedOptions = emptyOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(emptyOptions).isEqualTo(emptyOptions); + assertThat(copiedOptions.getModel()).isNull(); + assertThat(copiedOptions.getTemperature()).isNull(); + } + + @Test + void testCopyMutationDoesNotAffectOriginal() { + CohereChatOptions original = CohereChatOptions.builder() + .model("original-model") + .temperature(0.5) + .stop(List.of("original-stop")) + .toolContext(Map.of("original", "value")) + .build(); + + CohereChatOptions copy = original.copy(); + copy.setModel("modified-model"); + copy.setTemperature(0.8); + + // Original should remain unchanged + assertThat(original.getModel()).isEqualTo("original-model"); + assertThat(original.getTemperature()).isEqualTo(0.5); + + // Copy should have new values + assertThat(copy.getModel()).isEqualTo("modified-model"); + assertThat(copy.getTemperature()).isEqualTo(0.8); + } + + @Test + void testEqualsAndHashCode() { + CohereChatOptions options1 = CohereChatOptions.builder().model("test-model").temperature(0.7).build(); + + CohereChatOptions options2 = CohereChatOptions.builder().model("test-model").temperature(0.7).build(); + + CohereChatOptions options3 = CohereChatOptions.builder().model("different-model").temperature(0.7).build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + + assertThat(options1).isNotEqualTo(options3); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + + @Test + void testAllToolChoiceEnumValues() { + for (CohereApi.ChatCompletionRequest.ToolChoice toolChoice : CohereApi.ChatCompletionRequest.ToolChoice + .values()) { + + CohereChatOptions options = CohereChatOptions.builder().toolChoice(toolChoice).build(); + + assertThat(options.getToolChoice()).isEqualTo(toolChoice); + } + } + + @Test + void testChainedBuilderMethods() { + CohereChatOptions options = CohereChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .seed(123) + .internalToolExecutionEnabled(false) + .build(); + + // Verify all chained methods worked + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getSeed()).isEqualTo(123); + assertThat(options.getInternalToolExecutionEnabled()).isFalse(); + } + + @Test + void testBuilderAndSetterConsistency() { + // Build an object using builder + CohereChatOptions builderOptions = CohereChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .build(); + + // Create equivalent object using setters + CohereChatOptions setterOptions = new CohereChatOptions(); + setterOptions.setModel("test-model"); + setterOptions.setTemperature(0.7); + setterOptions.setTopP(0.9); + setterOptions.setMaxTokens(100); + + assertThat(builderOptions).isEqualTo(setterOptions); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingIT.java new file mode 100644 index 00000000000..fc68d35853c --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingIT.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import org.springframework.ai.cohere.CohereTestConfiguration; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +class CohereEmbeddingIT { + + private static final int EMBED_DIMENSIONS = 384; + + @Autowired + private CohereApi cohereApi; + + @Autowired + private CohereEmbeddingModel cohereEmbeddingModel; + + @Test + void defaultEmbedding() { + var embeddingResponse = this.cohereEmbeddingModel.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(EMBED_DIMENSIONS); + assertThat(this.cohereEmbeddingModel.dimensions()).isEqualTo(EMBED_DIMENSIONS); + } + + @ParameterizedTest + @CsvSource({ "embed-multilingual-light-v3.0, 384", "embed-english-light-v3.0, 384" }) + void defaultOptionsEmbedding(String model, int dimensions) { + var cohereEmbeddingOptions = CohereEmbeddingOptions.builder().model(model).build(); + var anotherCohereEmbeddingModel = CohereEmbeddingModel.builder() + .cohereApi(this.cohereApi) + .options(cohereEmbeddingOptions) + .build(); + var embeddingResponse = anotherCohereEmbeddingModel.embedForResponse(List.of("Hello World", "World is big")); + assertThat(embeddingResponse.getResults()).hasSize(2); + embeddingResponse.getResults().forEach(result -> { + assertThat(result).isNotNull(); + assertThat(result.getOutput()).hasSize(dimensions); + }); + assertThat(anotherCohereEmbeddingModel.dimensions()).isEqualTo(dimensions); + } + + @ParameterizedTest + @CsvSource({ "embed-multilingual-light-v3.0, 384", "embed-english-light-v3.0, 384" }) + void calledOptionsEmbedding(String model, int dimensions) { + var cohereEmbeddingOptions = CohereEmbeddingOptions.builder().model(model).build(); + var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big", "We are small"), + cohereEmbeddingOptions); + var embeddingResponse = this.cohereEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(3); + embeddingResponse.getResults().forEach(result -> { + assertThat(result).isNotNull(); + assertThat(result.getOutput()).hasSize(dimensions); + }); + assertThat(this.cohereEmbeddingModel.dimensions()).isEqualTo(EMBED_DIMENSIONS); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelObservationIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelObservationIT.java new file mode 100644 index 00000000000..cf2226e265b --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelObservationIT.java @@ -0,0 +1,118 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.List; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; +import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; + +/** + * Integration tests for observation instrumentation in {@link CohereEmbeddingModel}. + * + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereEmbeddingModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +public class CohereEmbeddingModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + CohereEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + var options = CohereEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue()) + .build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.COHERE.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), + CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public CohereApi cohereApi() { + return CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build(); + } + + @Bean + public CohereEmbeddingModel cohereEmbeddingModel(CohereApi cohereApi, + TestObservationRegistry observationRegistry) { + return CohereEmbeddingModel.builder() + .cohereApi(cohereApi) + .options(CohereEmbeddingOptions.builder().build()) + .retryTemplate(new RetryTemplate()) + .observationRegistry(observationRegistry) + .build(); + } + + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelTests.java new file mode 100644 index 00000000000..56f29dab3d4 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModelTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link CohereEmbeddingModel}. + * + * @author Ricken Bazolo + */ +class CohereEmbeddingModelTests { + + @Test + void testDimensionsForEmbedV4Model() { + CohereApi mockApi = createMockApiWithEmbeddingResponse(1024); + + CohereEmbeddingOptions options = CohereEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_V4.getValue()) + .build(); + + CohereEmbeddingModel model = CohereEmbeddingModel.builder() + .cohereApi(mockApi) + .metadataMode(MetadataMode.EMBED) + .options(options) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + assertThat(model.dimensions()).isEqualTo(1536); + } + + @Test + void testDimensionsForMultilingualV3Model() { + CohereApi mockApi = createMockApiWithEmbeddingResponse(1024); + + CohereEmbeddingOptions options = CohereEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue()) + .build(); + + CohereEmbeddingModel model = CohereEmbeddingModel.builder() + .cohereApi(mockApi) + .metadataMode(MetadataMode.EMBED) + .options(options) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + assertThat(model.dimensions()).isEqualTo(1024); + } + + @Test + void testDimensionsFallbackForUnknownModel() { + CohereApi mockApi = createMockApiWithEmbeddingResponse(512); + + // Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS + CohereEmbeddingOptions options = CohereEmbeddingOptions.builder().model("unknown-model").build(); + + CohereEmbeddingModel model = CohereEmbeddingModel.builder() + .cohereApi(mockApi) + .metadataMode(MetadataMode.EMBED) + .options(options) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + // Should fall back to super.dimensions() which detects dimensions from the API + // response + assertThat(model.dimensions()).isEqualTo(1024); + } + + @Test + void testAllEmbeddingModelsHaveDimensionMapping() { + // This test ensures that KNOWN_EMBEDDING_DIMENSIONS map stays in sync with the + // EmbeddingModel enum + // If a new model is added to the enum but not to the dimensions map, this test + // will help catch it + + for (CohereApi.EmbeddingModel embeddingModel : CohereApi.EmbeddingModel.values()) { + CohereApi mockApi = createMockApiWithEmbeddingResponse(1024); + CohereEmbeddingOptions options = CohereEmbeddingOptions.builder().model(embeddingModel.getValue()).build(); + + CohereEmbeddingModel model = CohereEmbeddingModel.builder() + .cohereApi(mockApi) + .metadataMode(MetadataMode.EMBED) + .options(options) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + // Each model should have a valid dimension (not the fallback -1) + assertThat(model.dimensions()).as("Model %s should have a dimension mapping", embeddingModel.getValue()) + .isGreaterThan(0); + } + } + + @Test + void testBuilderCreatesValidModel() { + CohereApi mockApi = createMockApiWithEmbeddingResponse(1024); + + CohereEmbeddingModel model = CohereEmbeddingModel.builder() + .cohereApi(mockApi) + .options(CohereEmbeddingOptions.builder() + .model(CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue()) + .build()) + .build(); + + assertThat(model).isNotNull(); + assertThat(model.dimensions()).isEqualTo(1024); + } + + private CohereApi createMockApiWithEmbeddingResponse(int dimensions) { + CohereApi mockApi = Mockito.mock(CohereApi.class); + + // Create a mock embedding response with the specified dimensions + // Cohere returns List> for embeddings + List embedding = new java.util.ArrayList<>(dimensions); + for (int i = 0; i < dimensions; i++) { + embedding.add(0.1); + } + + // Cohere can return embeddings for multiple texts + List> embeddings = List.of(embedding); + + CohereApi.EmbeddingResponse embeddingResponse = new CohereApi.EmbeddingResponse("test-id", embeddings, + List.of("test text"), "embeddings_floats"); + + when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingResponse)); + + return mockApi; + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingIT.java new file mode 100644 index 00000000000..679e9f85874 --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingIT.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.embedding; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.cohere.CohereTestConfiguration; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.ai.content.Media; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.ClassPathResource; +import org.springframework.util.MimeTypeUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = CohereTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+") +class CohereMultimodalEmbeddingIT { + + private static final int EMBED_DIMENSIONS = 1536; + + @Autowired + private CohereApi cohereApi; + + @Autowired + private CohereMultimodalEmbeddingModel cohereMultimodalEmbeddingModel; + + @Test + void imageEmbedding() { + + var document = Document.builder() + .media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png"))) + .build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = this.cohereMultimodalEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()).isEqualTo(ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.TEXT_PLAIN); + + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(EMBED_DIMENSIONS); + + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("embeddings_by_type"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); + + assertThat(this.cohereMultimodalEmbeddingModel.dimensions()).isEqualTo(EMBED_DIMENSIONS); + } + + @Test + void textAndImageEmbedding() { + + var textDocument = Document.builder().text("Hello World").build(); + + var imageDocument = Document.builder() + .media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png"))) + .build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(textDocument, imageDocument)); + + EmbeddingResponse embeddingResponse = this.cohereMultimodalEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(EMBED_DIMENSIONS); + + assertThat(embeddingResponse.getResults().get(1)).isNotNull(); + assertThat(embeddingResponse.getResults().get(1).getMetadata().getModalityType()).isEqualTo(ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(EMBED_DIMENSIONS); + + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("embeddings_by_type"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); + + assertThat(this.cohereMultimodalEmbeddingModel.dimensions()).isEqualTo(EMBED_DIMENSIONS); + } + +} diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/testutils/AbstractIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/testutils/AbstractIT.java new file mode 100644 index 00000000000..261dcd9231f --- /dev/null +++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/testutils/AbstractIT.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.cohere.testutils; + +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.cohere.api.CohereApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public abstract class AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(AbstractIT.class); + + @Autowired + protected ChatModel chatModel; + + @Autowired + protected CohereApi cohereApi; + + @Autowired + protected StreamingChatModel streamingChatModel; + + @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st") + protected Resource qaEvaluatorAccurateAnswerResource; + + @Value("classpath:/prompts/eval/qa-evaluator-not-related-message.st") + protected Resource qaEvaluatorNotRelatedResource; + + @Value("classpath:/prompts/eval/qa-evaluator-fact-based-answer.st") + protected Resource qaEvaluatorFactBasedAnswerResource; + + @Value("classpath:/prompts/eval/user-evaluator-message.st") + protected Resource userEvaluatorResource; + + @Value("classpath:/prompts/system-message.st") + protected Resource systemResource; + + protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) { + assertThat(response).isNotNull(); + String answer = response.getResult().getOutput().getText(); + logger.info("Question: {}", question); + logger.info("Answer:{}", answer); + PromptTemplate userPromptTemplate = PromptTemplate.builder() + .resource(this.userEvaluatorResource) + .variables(Map.of("question", question, "answer", answer)) + .build(); + SystemMessage systemMessage; + if (factBased) { + systemMessage = new SystemMessage(this.qaEvaluatorFactBasedAnswerResource); + } + else { + systemMessage = new SystemMessage(this.qaEvaluatorAccurateAnswerResource); + } + Message userMessage = userPromptTemplate.createMessage(); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + String yesOrNo = this.chatModel.call(prompt).getResult().getOutput().getText(); + logger.info("Is Answer related to question: {}", yesOrNo); + assert yesOrNo != null; + if (yesOrNo.equalsIgnoreCase("no")) { + SystemMessage notRelatedSystemMessage = new SystemMessage(this.qaEvaluatorNotRelatedResource); + prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); + String reasonForFailure = this.chatModel.call(prompt).getResult().getOutput().getText(); + fail(reasonForFailure); + } + else { + logger.info("Answer is related to question."); + assertThat(yesOrNo).isEqualTo("YES"); + } + } + +} diff --git a/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st new file mode 100644 index 00000000000..56270359545 --- /dev/null +++ b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st @@ -0,0 +1,3 @@ +You are an AI assistant who helps users to evaluate if the answers to questions are accurate. +You will be provided with a QUESTION and an ANSWER. +Your goal is to evaluate the QUESTION and ANSWER and reply with a YES or NO answer. \ No newline at end of file diff --git a/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st new file mode 100644 index 00000000000..22fc3e88d14 --- /dev/null +++ b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st @@ -0,0 +1,7 @@ +You are an AI evaluator. Your task is to verify if the provided ANSWER is a direct and accurate response to the given QUESTION. If the ANSWER is correct and directly answers the QUESTION, reply with "YES". If the ANSWER is not a direct response or is inaccurate, reply with "NO". + +For example: + +If the QUESTION is "What is the capital of France?" and the ANSWER is "Paris.", you should respond with "YES". +If the QUESTION is "What is the capital of France?" and the ANSWER is "France is in Europe.", respond with "NO". +Now, evaluate the following: diff --git a/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st new file mode 100644 index 00000000000..7c33e675e02 --- /dev/null +++ b/models/spring-ai-cohere/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st @@ -0,0 +1,4 @@ +You are an AI assistant who helps users to evaluate if the answers to questions are accurate. +You will be provided with a QUESTION and an ANSWER. +A previous evaluation has determined that QUESTION and ANSWER are not related. +Give an explanation as to why they are not related. \ No newline at end of file diff --git a/models/spring-ai-cohere/src/test/resources/prompts/eval/user-evaluator-message.st b/models/spring-ai-cohere/src/test/resources/prompts/eval/user-evaluator-message.st new file mode 100644 index 00000000000..b3fa3e902d2 --- /dev/null +++ b/models/spring-ai-cohere/src/test/resources/prompts/eval/user-evaluator-message.st @@ -0,0 +1,6 @@ +The question and answer to evaluate are: + +QUESTION: ```{question}``` + +ANSWER: ```{answer}``` + diff --git a/models/spring-ai-cohere/src/test/resources/prompts/system-message.st b/models/spring-ai-cohere/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..579febd8d9b --- /dev/null +++ b/models/spring-ai-cohere/src/test/resources/prompts/system-message.st @@ -0,0 +1,3 @@ +You are an AI assistant that helps people find information. +Your name is {name}. +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file diff --git a/models/spring-ai-cohere/src/test/resources/test.image.png b/models/spring-ai-cohere/src/test/resources/test.image.png new file mode 100644 index 00000000000..8abb4c81aea Binary files /dev/null and b/models/spring-ai-cohere/src/test/resources/test.image.png differ diff --git a/pom.xml b/pom.xml index 9695f90f231..c6f5961a314 100644 --- a/pom.xml +++ b/pom.xml @@ -118,6 +118,7 @@ auto-configurations/models/spring-ai-autoconfigure-model-google-genai auto-configurations/models/spring-ai-autoconfigure-model-zhipuai auto-configurations/models/spring-ai-autoconfigure-model-deepseek + auto-configurations/models/spring-ai-autoconfigure-model-cohere auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient @@ -189,6 +190,7 @@ models/spring-ai-google-genai-embedding models/spring-ai-zhipuai models/spring-ai-deepseek + models/spring-ai-cohere spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai @@ -210,6 +212,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai spring-ai-spring-boot-starters/spring-ai-starter-model-deepseek + spring-ai-spring-boot-starters/spring-ai-starter-model-cohere spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-repository-cassandra @@ -842,6 +845,7 @@ org.springframework.ai.vertexai.embedding/**/*IT.java org.springframework.ai.vertexai.gemini/**/*IT.java org.springframework.ai.zhipuai/**/*IT.java + org.springframework.ai.cohere/**/*IT.java org.springframework.ai.vectorstore**/CosmosDB**IT.java diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 53b3d6aa0d5..9b046e531b0 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -353,6 +353,12 @@ ${project.version} + + org.springframework.ai + spring-ai-cohere + ${project.version} + + @@ -717,6 +723,12 @@ ${project.version} + + org.springframework.ai + spring-ai-autoconfigure-model-cohere + ${project.version} + + org.springframework.ai @@ -1084,6 +1096,12 @@ ${project.version} + + org.springframework.ai + spring-ai-starter-model-cohere + ${project.version} + + diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 88105725a69..49f460a4bc3 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -45,6 +45,11 @@ public enum AiProvider { */ BEDROCK_CONVERSE("bedrock_converse"), + /** + * AI system provided by Cohere. + */ + COHERE("cohere"), + /** * AI system provided by DeepSeek. */ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index d4b40d8dd3f..fd289524caa 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -17,6 +17,7 @@ **** xref:api/chat/bedrock-converse.adoc[Amazon Bedrock Converse] **** xref:api/chat/anthropic-chat.adoc[Anthropic] **** xref:api/chat/azure-openai-chat.adoc[Azure OpenAI] +**** xref:api/chat/cohere-chat.adoc[Cohere] **** xref:api/chat/deepseek-chat.adoc[DeepSeek] **** xref:api/chat/dmr-chat.adoc[Docker Model Runner] **** Google @@ -56,6 +57,9 @@ ***** xref:api/embeddings/vertexai-embeddings-text.adoc[Text Embedding] ***** xref:api/embeddings/vertexai-embeddings-multimodal.adoc[Multimodal Embedding] **** xref:api/embeddings/zhipuai-embeddings.adoc[ZhiPu AI] +**** Cohere +***** xref:api/embeddings/cohere-embeddings-text.adoc[Text Embedding] +***** xref:api/embeddings/cohere-embeddings-multimodal.adoc[Multimodal Embedding] *** xref:api/imageclient.adoc[Image Models] **** xref:api/image/azure-openai-image.adoc[Azure OpenAI] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/cohere-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/cohere-chat.adoc new file mode 100644 index 00000000000..8bd94879ea5 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/cohere-chat.adoc @@ -0,0 +1,339 @@ += Cohere Chat + +Spring AI supports the various AI language models from Cohere. You can interact with Cohere language models and create multilingual conversational assistants based on Cohere's powerful models. + +== Prerequisites + +You will need to create an API key with Cohere to access Cohere language models. + +Create an account at https://dashboard.cohere.com/welcome/register[Cohere registration page] and generate the token on the https://dashboard.cohere.com/api-keys[API Keys page]. + +The Spring AI project defines a configuration property named `spring.ai.cohere.api-key` that you should set to the value of the API Key obtained from dashboard.cohere.com. + +You can set this configuration property in your `application.properties` file: + +[source,properties] +---- +spring.ai.cohere.api-key= +---- + +Alternatively, you can set this as an environment variable: + +[source,bash] +---- +export COHERE_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Maven Central and Spring Snapshot repositories. +Refer to the xref:getting-started.adoc#artifact-repositories[Artifact Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + +== Auto-configuration + +[NOTE] +==== +There has been a significant change in the Spring AI auto-configuration, starter modules' artifact names. +Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information. +==== + +Spring AI provides Spring Boot auto-configuration for the Cohere Chat Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Cohere chat model. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.cohere` is used as the property prefix that lets you connect to Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.cohere.base-url | The URL to connect to | https://api.cohere.com +| spring.ai.cohere.api-key | The API Key | - +|==== + +==== Configuration Properties + +[NOTE] +==== +Enabling and disabling of the chat auto-configurations are now configured via top level properties with the prefix `spring.ai.model.chat`. + +To enable, spring.ai.model.chat=cohere (It is enabled by default) + +To disable, spring.ai.model.chat=none (or any value which doesn't match cohere) + +This change is done to allow configuration of multiple models. +==== + +The prefix `spring.ai.cohere.chat` is the property prefix that lets you configure the chat model implementation for Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.cohere.chat.enabled (Removed and no longer valid) | Enable Cohere chat model. | true +| spring.ai.model.chat | Enable Cohere chat model. | cohere +| spring.ai.cohere.chat.base-url | Optional override for the `spring.ai.cohere.base-url` property to provide chat-specific URL. | - +| spring.ai.cohere.chat.api-key | Optional override for the `spring.ai.cohere.api-key` to provide chat-specific API Key. | - +| spring.ai.cohere.chat.options.model | This is the Cohere Chat model to use | `command-r7b-12-2024` (see available models below) +| spring.ai.cohere.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify `temperature` and `p` for the same completions request as the interaction of these two settings is difficult to predict. | 0.3 +| spring.ai.cohere.chat.options.max-tokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.cohere.chat.options.p | Ensures that only the most likely tokens, with total probability mass of p, are considered for generation at each step. If both k and p are enabled, p acts after k. min value of 0.01, max value of 0.99. | 1.0 +| spring.ai.cohere.chat.options.k | Ensures that only the top k most likely tokens are considered for generation at each step. When k is set to 0, k-sampling is disabled. min value of 0, max value of 500. | 0 +| spring.ai.cohere.chat.options.frequency-penalty | Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Min value of 0.0, max value of 1.0. | 0.0 +| spring.ai.cohere.chat.options.presence-penalty | Used to reduce repetitiveness of generated tokens. Similar to frequency_penalty, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Min value of 0.0, max value of 1.0. | 0.0 +| spring.ai.cohere.chat.options.seed | If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. | - +| spring.ai.cohere.chat.options.stop-sequences | A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens. | - +| spring.ai.cohere.chat.options.response-format | An object specifying the format that the model must output. Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.| - +| spring.ai.cohere.chat.options.safety-mode | Used to select the safety instruction inserted into the prompt. Can be OFF, CONTEXTUAL, or STRICT. When OFF is specified, the safety instruction will be omitted. | CONTEXTUAL +| spring.ai.cohere.chat.options.logprobs | When set to true, the log probabilities of the generated tokens will be included in the response. | false +| spring.ai.cohere.chat.options.strict-tools | When enabled, tool calls are validated against the tool JSON schemas. | - +| spring.ai.cohere.chat.options.tools | A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. | - +| spring.ai.cohere.chat.options.tool-choice | Controls which (if any) function is called by the model. `none` means the model will not call a function and instead generates a message. `required` means the model can pick between generating a message or calling a function. Specifying a particular function via `{"type: "function", "function": {"name": "my_function"}}` forces the model to call that function. `required` is the default when no functions are present. `required` is the default if functions are present. | - +| spring.ai.cohere.chat.options.tool-names | List of tools, identified by their names, to enable for function calling in a single prompt request. Tools with those names must exist in the ToolCallback registry. | - +| spring.ai.cohere.chat.options.tool-callbacks | Tool Callbacks to register with the ChatModel. | - +| spring.ai.cohere.chat.options.internal-tool-execution-enabled | If false, the Spring AI will not handle the tool calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the tool calls, dispatch them to the appropriate function, and return the results. If true (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | true +|==== + +NOTE: You can override the common `spring.ai.cohere.base-url` and `spring.ai.cohere.api-key` for the `ChatModel` and `EmbeddingModel` implementations. +The `spring.ai.cohere.chat.base-url` and `spring.ai.cohere.chat.api-key` properties, if set, take precedence over the common properties. +This is useful if you want to use different Cohere accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.cohere.chat.options` can be overridden at runtime by adding request-specific <> to the `Prompt` call. + +== Available Models + +Cohere provides several chat models, each optimized for different use cases: + +[cols="2,1,4", stripes=even] +|==== +| Model | Context Length | Description + +| `command-a-03-2025` +| 128K tokens +| Latest flagship model with enhanced reasoning capabilities. Best overall performance for complex tasks. + +| `command-a-reasoning-08-2025` +| 128K tokens +| Specialized model optimized for reasoning tasks, mathematical problem-solving, and logical deduction. + +| `command-a-translate-08-2025` +| 128K tokens +| Optimized for translation tasks across multiple languages. Provides high-quality translations. + +| `command-a-vision-07-2025` +| 128K tokens +| Multimodal model with vision capabilities. Can process and understand images along with text. + +| `command-r7b-12-2024` +| 128K tokens +| Lightweight 7 billion parameter model. Faster and more cost-effective while maintaining good quality. Default model. + +| `command-r-plus-08-2024` +| 128K tokens +| Enhanced version of Command R with improved performance and multilingual capabilities. + +| `command-r-08-2024` +| 128K tokens +| General-purpose model with strong multilingual support and retrieval-augmented generation capabilities. +|==== + + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java[CohereChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `CohereChatModel(api, options)` constructor or the `spring.ai.cohere.chat.options.*` properties. + +At run-time, you can override the default options by adding new, request-specific options to the `Prompt` call. +For example, to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A.getName()) + .temperature(0.5) + .build() + )); +---- + +TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java[CohereChatOptions] you can use a portable link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java[ChatOptions#builder()]. + +== Function Calling + +You can register custom Java functions with the `CohereChatModel` and have the Cohere model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This is a powerful technique to connect the LLM capabilities with external tools and APIs. +Read more about xref:api/tools.adoc[Tool Calling]. + +== Multimodal + +Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats. +Cohere supports text and vision modalities. + +=== Vision + +Cohere models that offer vision multimodal support include `command-a-vision-07-2025`. +Refer to the link:https://docs.cohere.com/docs/vision[Vision] guide for more information. + +The Cohere link:https://docs.cohere.com/reference/chat[Chat API] can incorporate a list of base64-encoded images or image urls with the message. +Spring AI's link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/messages/Message.java[Message] interface facilitates multimodal AI models by introducing the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-commons/src/main/java/org/springframework/ai/content/Media.java[Media] type. +This type encompasses data and details regarding media attachments in messages, utilizing Spring's `org.springframework.util.MimeType` and a `org.springframework.core.io.Resource` for the raw media data. + +Below is a code example illustrating the fusion of user text with an image: + +[source,java] +---- +var imageResource = new ClassPathResource("/multimodal.test.png"); + +var userMessage = new UserMessage("Explain what do you see on this picture?", + new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); + +ChatResponse response = chatModel.call(new Prompt(userMessage, + ChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A_VISION.getName()).build())); +---- + +TIP: You can pass multiple images as well. + +== Sample Controller (Auto-configuration) + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-starter-model-cohere` to your pom (or gradle) dependencies. + +Add a `application.properties` file under the `src/main/resources` directory to enable and configure the Cohere chat model: + +[source,application.properties] +---- +spring.ai.cohere.api-key=YOUR_API_KEY +spring.ai.cohere.chat.options.model=command-r7b-12-2024 +spring.ai.cohere.chat.options.temperature=0.7 +---- + +TIP: Replace the `api-key` with your Cohere credentials. + +This will create a `CohereChatModel` implementation that you can inject into your classes. +Here is an example of a simple `@RestController` class that uses the chat model for text generations. + +[source,java] +---- +@RestController +public class ChatController { + + private final CohereChatModel chatModel; + + @Autowired + public ChatController(CohereChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", this.chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return this.chatModel.stream(prompt); + } +} +---- + +== Manual Configuration + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java[CohereChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <> to connect to the Cohere service. + +Add the `spring-ai-cohere` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a `CohereChatModel` and use it for text generations: + +[source,java] +---- +var cohereApi = new CohereApi(System.getenv("COHERE_API_KEY")); +var chatModel = new CohereChatModel(cohereApi, CohereChatOptions.builder() + .model(CohereApi.ChatModel.COMMAND_A.getName()) + .temperature(0.4) + .build()); + +ChatResponse response = chatModel.call(new Prompt("Generate the names of 5 famous pirates.")); +---- + +=== Low-level CohereApi Client [[low-level-api]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java[CohereApi] provides a lightweight Java client for link:https://docs.cohere.com/reference/chat[Cohere API]. + +Here is a simple snippet showing how to use the API programmatically: + +[source,java] +---- +CohereApi cohereApi = new CohereApi(System.getenv("COHERE_API_KEY")); +ChatCompletionMessage message = new ChatCompletionMessage("Hello world", Role.USER); +ResponseEntity response = cohereApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(message), CohereApi.ChatModel.COMMAND_A.getName(), 0.8, false)); +---- + +==== CohereApi Samples + +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java[CohereApiIT.java] tests provide some general examples of how to use the lightweight library. + +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatModelIT.java[CohereChatModelIT.java] tests show examples of using function calling and streaming. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-multimodal.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-multimodal.adoc new file mode 100644 index 00000000000..464ac8b696b --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-multimodal.adoc @@ -0,0 +1,351 @@ += Cohere Multimodal Embeddings + +Cohere supports two types of embeddings models, text and multimodal. +This document describes how to create multimodal embeddings using the Cohere link:https://docs.cohere.com/docs/embeddings[Multimodal embeddings API]. + +The multimodal embeddings model generates 1536-dimension vectors based on the input you provide, which can include a combination of image and text data. +The embedding vectors can then be used for subsequent tasks like image classification or visual search. + +The image embedding vector and text embedding vector are in the same semantic space with the same dimensionality. +Consequently, these vectors can be used interchangeably for use cases like searching images by text, or searching text by image. + +NOTE: The Cohere Multimodal API imposes the following limits: maximum 1 image per request, maximum 5MB per image, supported formats are JPEG, PNG, WebP, and GIF. + +TIP: For text-only embedding use cases, we recommend using the xref:api/embeddings/cohere-embeddings-text.adoc[Cohere text-embeddings model] instead. + +== Prerequisites + +You will need to create an API key with Cohere to access Cohere embedding models. + +Create an account at https://dashboard.cohere.com/welcome/register[Cohere registration page] and generate the token on the https://dashboard.cohere.com/api-keys[API Keys page]. + +The Spring AI project defines a configuration property named `spring.ai.cohere.api-key` that you should set to the value of the API Key obtained from dashboard.cohere.com. + +You can set this configuration property in your `application.properties` file: + +[source,properties] +---- +spring.ai.cohere.api-key= +---- + +For enhanced security when handling sensitive information like API keys, you can use Spring Expression Language (SpEL) to reference an environment variable: + +[source,yaml] +---- +# In application.yml +spring: + ai: + cohere: + api-key: ${COHERE_API_KEY} +---- + +[source,bash] +---- +# In your environment or .env file +export COHERE_API_KEY= +---- + +You can also set this configuration programmatically in your application code: + +[source,java] +---- +// Retrieve API key from a secure source or environment variable +String apiKey = System.getenv("COHERE_API_KEY"); +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Maven Central and Spring Snapshot repositories. +Refer to the xref:getting-started.adoc#artifact-repositories[Artifact Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +[NOTE] +==== +There has been a significant change in the Spring AI auto-configuration, starter modules' artifact names. +Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information. +==== + +Spring AI provides Spring Boot auto-configuration for the Cohere Multimodal Embedding Model. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Cohere Multimodal Embedding model. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.cohere` is used as the property prefix that lets you connect to Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.cohere.base-url | The URL to connect to | https://api.cohere.com +| spring.ai.cohere.api-key | The API Key | - +|==== + +==== Configuration Properties + +[NOTE] +==== +Enabling and disabling of the multimodal embedding auto-configurations are now configured via top level properties with the prefix `spring.ai.model.embedding.multimodal`. + +To enable, spring.ai.model.embedding.multimodal=cohere (It is enabled by default) + +To disable, spring.ai.model.embedding.multimodal=none (or any value which doesn't match cohere) + +This change is done to allow configuration of multiple models. +==== + +The prefix `spring.ai.cohere.embedding.multimodal` is the property prefix that configures the multimodal `EmbeddingModel` implementation for Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.model.embedding.multimodal | Enable Cohere multimodal embedding model. | cohere +| spring.ai.cohere.embedding.multimodal.base-url | Optional overrides the spring.ai.cohere.base-url to provide embedding specific url | - +| spring.ai.cohere.embedding.multimodal.api-key | Optional overrides the spring.ai.cohere.api-key to provide embedding specific api-key | - +| spring.ai.cohere.embedding.multimodal.options.model | The model to use | embed-v4 +| spring.ai.cohere.embedding.multimodal.options.input-type | The type of input (search_document, search_query, classification, clustering) | classification +| spring.ai.cohere.embedding.multimodal.options.embedding-types | The types of embeddings to return (float, int8, uint8, binary, ubinary) | [float] +| spring.ai.cohere.embedding.multimodal.options.truncate | How to handle inputs longer than maximum token length (NONE, START, END) | - +|==== + +NOTE: You can override the common `spring.ai.cohere.base-url` and `spring.ai.cohere.api-key` for the `ChatModel` and `EmbeddingModel` implementations. +The `spring.ai.cohere.embedding.multimodal.base-url` and `spring.ai.cohere.embedding.multimodal.api-key` properties if set take precedence over the common properties. +Similarly, the `spring.ai.cohere.chat.base-url` and `spring.ai.cohere.chat.api-key` properties if set take precedence over the common properties. +This is useful if you want to use different Cohere accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.cohere.embedding.multimodal.options` can be overridden at runtime by adding a request specific <> to the `DocumentEmbeddingRequest` call. + +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java[CohereMultimodalEmbeddingOptions.java] provides the Cohere multimodal configurations, such as the model to use and etc. + +The default options can be configured using the `spring.ai.cohere.embedding.multimodal.options` properties as well. + +At start-time use the `CohereMultimodalEmbeddingModel` constructor to set the default options used for all embedding requests. +At run-time you can override the default options, using a `CohereMultimodalEmbeddingOptions` instance as part of your `DocumentEmbeddingRequest`. + +For example to override the default model name and input type for a specific request: + +[source,java] +---- +// Create a document with text only +Document textDocument = new Document("Hello World"); + +// Create a document with an image +Media imageMedia = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")); +Document imageDocument = new Document("", List.of(imageMedia), Map.of()); + +// Create a document with both text and image +Document multimodalDocument = new Document("Describe this image", List.of(imageMedia), Map.of()); + +// Create embedding request with custom options +DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest( + List.of(textDocument, imageDocument, multimodalDocument), + CohereMultimodalEmbeddingOptions.builder() + .model("embed-v4") + .inputType(InputType.CLASSIFICATION) + .build()); + +EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); +---- + +== Understanding Input Types + +Cohere embeddings support different input types to optimize the embeddings for specific use cases: + +* `SEARCH_DOCUMENT`: Use when embedding documents to be retrieved in a search system +* `SEARCH_QUERY`: Use when embedding search queries to match against documents +* `CLASSIFICATION`: Use for classification tasks (text or image classification) +* `CLUSTERING`: Use for clustering documents or images by similarity + +For best results in semantic search applications, use `SEARCH_DOCUMENT` for your corpus (both text and images) and `SEARCH_QUERY` for user queries. + +== Image Format Requirements + +When working with images in Cohere multimodal embeddings, note the following requirements: + +* Maximum 1 image per request +* Maximum 5MB per image +* Supported formats: JPEG, PNG, WebP, GIF +* Images are converted to Data URI format (base64-encoded) before sending to the API + +The Spring AI Cohere integration handles the conversion automatically when you provide images through `Media` objects. + +== Sample Controller + +This will create a `DocumentEmbeddingModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the multimodal embedding implementation. + +[source,application.properties] +---- +spring.ai.cohere.api-key=YOUR_API_KEY +spring.ai.model.embedding.multimodal=cohere +spring.ai.cohere.embedding.multimodal.options.model=embed-v4 +spring.ai.cohere.embedding.multimodal.options.input-type=classification +---- + +[source,java] +---- +@RestController +public class MultimodalEmbeddingController { + + private final DocumentEmbeddingModel embeddingModel; + + @Autowired + public MultimodalEmbeddingController(DocumentEmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @GetMapping("/ai/embedding/text") + public Map embedText(@RequestParam(value = "message", defaultValue = "Hello World") String message) { + Document document = new Document(message); + DocumentEmbeddingRequest request = new DocumentEmbeddingRequest( + List.of(document), + EmbeddingOptions.EMPTY); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(request); + return Map.of("embedding", embeddingResponse); + } + + @PostMapping("/ai/embedding/image") + public Map embedImage(@RequestParam("file") MultipartFile file) throws IOException { + Media imageMedia = new Media( + MimeTypeUtils.parseMimeType(file.getContentType()), + file.getResource()); + + Document document = new Document("", List.of(imageMedia), Map.of()); + DocumentEmbeddingRequest request = new DocumentEmbeddingRequest( + List.of(document), + EmbeddingOptions.EMPTY); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(request); + return Map.of("embedding", embeddingResponse); + } + + @PostMapping("/ai/embedding/multimodal") + public Map embedMultimodal( + @RequestParam(value = "message", defaultValue = "Describe this image") String message, + @RequestParam("file") MultipartFile file) throws IOException { + + Media imageMedia = new Media( + MimeTypeUtils.parseMimeType(file.getContentType()), + file.getResource()); + + Document document = new Document(message, List.of(imageMedia), Map.of()); + DocumentEmbeddingRequest request = new DocumentEmbeddingRequest( + List.of(document), + EmbeddingOptions.EMPTY); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(request); + return Map.of("embedding", embeddingResponse); + } +} +---- + +== Manual Configuration + +If you are not using Spring Boot, you can manually configure the Cohere Multimodal Embedding Model. +For this add the `spring-ai-cohere` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +NOTE: The `spring-ai-cohere` dependency provides access also to the `CohereChatModel`. +For more information about the `CohereChatModel` refer to the link:../chat/cohere-chat.html[Cohere Chat Client] section. + +Next, create a `CohereMultimodalEmbeddingModel` instance and use it to compute embeddings for text and images: + +[source,java] +---- +var cohereApi = new CohereApi(System.getenv("COHERE_API_KEY")); + +var embeddingModel = CohereMultimodalEmbeddingModel.builder() + .cohereApi(cohereApi) + .options(CohereMultimodalEmbeddingOptions.builder() + .model("embed-v4") + .inputType(InputType.CLASSIFICATION) + .embeddingTypes(List.of(EmbeddingType.FLOAT)) + .build()) + .build(); + +// Embedding text +Document textDocument = new Document("Hello World"); + +// Embedding an image +Media imageMedia = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")); +Document imageDocument = new Document("", List.of(imageMedia), Map.of()); + +// Embedding text with image +Document multimodalDocument = new Document("Describe this image", List.of(imageMedia), Map.of()); + +DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest( + List.of(textDocument, imageDocument, multimodalDocument), + EmbeddingOptions.EMPTY); + +EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + +// Each document gets its own embedding result +assertThat(embeddingResponse.getResults()).hasSize(3); +assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); +---- + +The `CohereMultimodalEmbeddingOptions` provides the configuration information for the embedding requests. +The options class offers a `builder()` for easy options creation. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-text.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-text.adoc new file mode 100644 index 00000000000..2a7b1cf72ec --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/cohere-embeddings-text.adoc @@ -0,0 +1,313 @@ += Cohere Text Embeddings + +Cohere supports two types of embeddings models, text and multimodal. +This document describes how to create text embeddings using the Cohere link:https://docs.cohere.com/docs/embeddings[Text embeddings API]. + +Embeddings are vectorial representations of text that capture the semantic meaning of paragraphs through their position in a high dimensional vector space. Cohere Embeddings API offers cutting-edge, state-of-the-art embeddings for text, which can be used for many NLP tasks. + +== Available Models + +Cohere provides several embedding models, each optimized for different use cases: + +[cols="2,2,1,4", stripes=even] +|==== +| Model | Dimensions | Use Case | Description + +| `embed-v4` +| 1024 +| General text +| Latest general-purpose embedding model suitable for semantic search, clustering, and text similarity tasks. Offers improved performance and multilingual support. + +| `embed-english-v3.0` +| 1024 +| English text +| Optimized for English language content. Ideal for semantic search and text classification tasks with English documents. + +| `embed-multilingual-v3.0` +| 1024 +| Multilingual text +| Supports over 100 languages. Perfect for applications requiring multilingual semantic search and text similarity. + +| `embed-english-light-v3.0` +| 384 +| English text (lightweight) +| Lightweight model for English content. Faster inference with reduced memory footprint while maintaining good accuracy. + +| `embed-multilingual-light-v3.0` +| 384 +| Multilingual text (lightweight) +| Lightweight multilingual model. Balances performance and resource usage for multilingual applications. +|==== + +When choosing a model: + +* Use `embed-v4` for the latest features and best overall performance +* Use `embed-english-v3.0` for high-quality English-only embeddings +* Use `embed-multilingual-v3.0` when working with multiple languages +* Use the "light" variants when you need faster inference or have resource constraints + +TIP: For multimodal embedding use cases (combining text and images), we recommend using the xref:api/embeddings/cohere-embeddings-multimodal.adoc[Cohere Multimodal Embedding model] instead. + +== Prerequisites + +You will need to create an API key with Cohere to access Cohere embedding models. + +Create an account at https://dashboard.cohere.com/welcome/register[Cohere registration page] and generate the token on the https://dashboard.cohere.com/api-keys[API Keys page]. + +The Spring AI project defines a configuration property named `spring.ai.cohere.api-key` that you should set to the value of the API Key obtained from dashboard.cohere.com. + +You can set this configuration property in your `application.properties` file: + +[source,properties] +---- +spring.ai.cohere.api-key= +---- + +For enhanced security when handling sensitive information like API keys, you can use Spring Expression Language (SpEL) to reference an environment variable: + +[source,yaml] +---- +# In application.yml +spring: + ai: + cohere: + api-key: ${COHERE_API_KEY} +---- + +[source,bash] +---- +# In your environment or .env file +export COHERE_API_KEY= +---- + +You can also set this configuration programmatically in your application code: + +[source,java] +---- +// Retrieve API key from a secure source or environment variable +String apiKey = System.getenv("COHERE_API_KEY"); +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Maven Central and Spring Snapshot repositories. +Refer to the xref:getting-started.adoc#artifact-repositories[Artifact Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +[NOTE] +==== +There has been a significant change in the Spring AI auto-configuration, starter modules' artifact names. +Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information. +==== + +Spring AI provides Spring Boot auto-configuration for the Cohere Text Embedding Model. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Cohere Embedding model. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.cohere` is used as the property prefix that lets you connect to Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.cohere.base-url | The URL to connect to | https://api.cohere.com +| spring.ai.cohere.api-key | The API Key | - +|==== + +==== Configuration Properties + +[NOTE] +==== +Enabling and disabling of the embedding auto-configurations are now configured via top level properties with the prefix `spring.ai.model.embedding`. + +To enable, spring.ai.model.embedding=cohere (It is enabled by default) + +To disable, spring.ai.model.embedding=none (or any value which doesn't match cohere) + +This change is done to allow configuration of multiple models. +==== + +The prefix `spring.ai.cohere.embedding` is property prefix that configures the `EmbeddingModel` implementation for Cohere. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.model.embedding | Enable Cohere embedding model. | cohere +| spring.ai.cohere.embedding.base-url | Optional overrides the spring.ai.cohere.base-url to provide embedding specific url | - +| spring.ai.cohere.embedding.api-key | Optional overrides the spring.ai.cohere.api-key to provide embedding specific api-key | - +| spring.ai.cohere.embedding.metadata-mode | Document content extraction mode. | EMBED +| spring.ai.cohere.embedding.options.model | The model to use | embed-v4 +| spring.ai.cohere.embedding.options.input-type | The type of input (search_document, search_query, classification, clustering) | classification +| spring.ai.cohere.embedding.options.embedding-types | The types of embeddings to return (float, int8, uint8, binary, ubinary) | [float] +| spring.ai.cohere.embedding.options.truncate | How to handle inputs longer than maximum token length (NONE, START, END) | - +|==== + +NOTE: You can override the common `spring.ai.cohere.base-url` and `spring.ai.cohere.api-key` for the `ChatModel` and `EmbeddingModel` implementations. +The `spring.ai.cohere.embedding.base-url` and `spring.ai.cohere.embedding.api-key` properties if set take precedence over the common properties. +Similarly, the `spring.ai.cohere.chat.base-url` and `spring.ai.cohere.chat.api-key` properties if set take precedence over the common properties. +This is useful if you want to use different Cohere accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.cohere.embedding.options` can be overridden at runtime by adding a request specific <> to the `EmbeddingRequest` call. + +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java[CohereEmbeddingOptions.java] provides the Cohere configurations, such as the model to use and etc. + +The default options can be configured using the `spring.ai.cohere.embedding.options` properties as well. + +At start-time use the `CohereEmbeddingModel` constructor to set the default options used for all embedding requests. +At run-time you can override the default options, using a `CohereEmbeddingOptions` instance as part of your `EmbeddingRequest`. + +For example to override the default model name and input type for a specific request: + +[source,java] +---- +// Using embed-v4 for general text embeddings +EmbeddingResponse embeddingResponse = embeddingModel.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + CohereEmbeddingOptions.builder() + .model("embed-v4") + .inputType(InputType.SEARCH_DOCUMENT) + .build())); + +// Using embed-multilingual-v3.0 for multilingual content +EmbeddingResponse multilingualResponse = embeddingModel.call( + new EmbeddingRequest(List.of("Bonjour le monde", "Hola mundo"), + CohereEmbeddingOptions.builder() + .model("embed-multilingual-v3.0") + .inputType(InputType.CLUSTERING) + .build())); +---- + +== Understanding Input Types + +Cohere embeddings support different input types to optimize the embeddings for specific use cases: + +* `SEARCH_DOCUMENT`: Use when embedding documents to be retrieved in a search system +* `SEARCH_QUERY`: Use when embedding search queries to match against documents +* `CLASSIFICATION`: Use for text classification tasks +* `CLUSTERING`: Use for clustering documents by similarity + +For best results in semantic search applications, use `SEARCH_DOCUMENT` for your corpus and `SEARCH_QUERY` for user queries. + +== Sample Controller + +This will create a `EmbeddingModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the `EmbeddingModel` implementation. + +[source,application.properties] +---- +spring.ai.cohere.api-key=YOUR_API_KEY +spring.ai.cohere.embedding.options.model=embed-v4 +spring.ai.cohere.embedding.options.input-type=classification +---- + +[source,java] +---- +@RestController +public class EmbeddingController { + + private final EmbeddingModel embeddingModel; + + @Autowired + public EmbeddingController(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @GetMapping("/ai/embedding") + public Map embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var embeddingResponse = this.embeddingModel.embedForResponse(List.of(message)); + return Map.of("embedding", embeddingResponse); + } +} +---- + +== Manual Configuration + +If you are not using Spring Boot, you can manually configure the Cohere Text Embedding Model. +For this add the `spring-ai-cohere` dependency to your project's Maven `pom.xml` file: +[source, xml] +---- + + org.springframework.ai + spring-ai-cohere + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-cohere' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +NOTE: The `spring-ai-cohere` dependency provides access also to the `CohereChatModel`. +For more information about the `CohereChatModel` refer to the link:../chat/cohere-chat.html[Cohere Chat Client] section. + +Next, create a `CohereEmbeddingModel` instance and use it to compute the similarity between two input texts: + +[source,java] +---- +var cohereApi = new CohereApi(System.getenv("COHERE_API_KEY")); + +var embeddingModel = new CohereEmbeddingModel(this.cohereApi, + CohereEmbeddingOptions.builder() + .model("embed-v4") + .inputType(InputType.CLASSIFICATION) + .embeddingTypes(List.of(EmbeddingType.FLOAT)) + .build()); + +EmbeddingResponse embeddingResponse = this.embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); +---- + +The `CohereEmbeddingOptions` provides the configuration information for the embedding requests. +The options class offers a `builder()` for easy options creation. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/SpringAIModels.java b/spring-ai-model/src/main/java/org/springframework/ai/model/SpringAIModels.java index 19f8a8a3258..53361bba703 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/SpringAIModels.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/SpringAIModels.java @@ -60,4 +60,6 @@ private SpringAIModels() { public static final String ELEVEN_LABS = "elevenlabs"; + public static final String COHERE = "cohere"; + } diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-cohere/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-cohere/pom.xml new file mode 100644 index 00000000000..4b575c7edf9 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-cohere/pom.xml @@ -0,0 +1,54 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-starter-model-cohere + jar + Spring AI Starter - Cohere + Spring AI Cohere Spring Boot Starter + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-autoconfigure-model-cohere + ${project.parent.version} + + + + org.springframework.ai + spring-ai-cohere + ${project.parent.version} + + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-client + ${project.parent.version} + + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + + +