diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml
new file mode 100644
index 00000000000..737dfbb162c
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/pom.xml
@@ -0,0 +1,102 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai-parent
+ 2.0.0-SNAPSHOT
+ ../../../pom.xml
+
+ spring-ai-autoconfigure-model-cohere
+ jar
+ Spring AI Cohere Auto Configuration
+ Spring AI Cohere Auto Configuration
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+
+ org.springframework.ai
+ spring-ai-cohere
+ ${project.parent.version}
+ true
+
+
+
+ org.springframework.ai
+ spring-ai-autoconfigure-model-tool
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-autoconfigure-retry
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-autoconfigure-model-chat-observation
+ ${project.parent.version}
+
+
+
+ org.springframework.boot
+ spring-boot-starter
+ true
+
+
+
+ org.springframework.boot
+ spring-boot-starter-webclient
+ true
+
+
+ org.springframework.boot
+ spring-boot-starter-restclient
+ true
+
+
+
+ org.springframework.boot
+ spring-boot-configuration-processor
+ true
+
+
+
+ org.springframework.boot
+ spring-boot-autoconfigure-processor
+ true
+
+
+
+
+ org.springframework.ai
+ spring-ai-test
+ ${project.parent.version}
+ test
+
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+ org.mockito
+ mockito-core
+ test
+
+
+
+
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java
new file mode 100644
index 00000000000..4fee095d6c9
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatAutoConfiguration.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import io.micrometer.observation.ObservationRegistry;
+
+import org.springframework.ai.chat.observation.ChatModelObservationConvention;
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.chat.CohereChatModel;
+import org.springframework.ai.model.SpringAIModelProperties;
+import org.springframework.ai.model.SpringAIModels;
+import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
+import org.springframework.ai.model.tool.ToolCallingManager;
+import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
+import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
+import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
+import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
+
+/**
+ * Chat {@link AutoConfiguration Auto-configuration} for Cohere.
+ *
+ * @author Ricken Bazolo
+ */
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
+ SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class })
+@EnableConfigurationProperties({ CohereCommonProperties.class, CohereChatProperties.class })
+@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.COHERE,
+ matchIfMissing = true)
+@ConditionalOnClass(CohereApi.class)
+@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class,
+ ToolCallingAutoConfiguration.class })
+public class CohereChatAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean
+ public CohereChatModel chereChatModel(CohereCommonProperties commonProperties, CohereChatProperties chatProperties,
+ ObjectProvider restClientBuilderProvider,
+ ObjectProvider webClientBuilderProvider, ToolCallingManager toolCallingManager,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
+ ObjectProvider observationRegistry,
+ ObjectProvider observationConvention,
+ ObjectProvider cohereToolExecutionEligibilityPredicate) {
+ var cohereApi = cohereApi(chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getBaseUrl(),
+ commonProperties.getBaseUrl(), restClientBuilderProvider.getIfAvailable(RestClient::builder),
+ webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler);
+
+ var chatModel = CohereChatModel.builder()
+ .cohereApi(cohereApi)
+ .defaultOptions(chatProperties.getOptions())
+ .toolCallingManager(toolCallingManager)
+ .toolExecutionEligibilityPredicate(
+ cohereToolExecutionEligibilityPredicate.getIfUnique(DefaultToolExecutionEligibilityPredicate::new))
+ .retryTemplate(new RetryTemplate())
+ .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
+ .build();
+
+ observationConvention.ifAvailable(chatModel::setObservationConvention);
+
+ return chatModel;
+ }
+
+ private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
+ RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
+ ResponseErrorHandler responseErrorHandler) {
+
+ var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
+ var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
+
+ Assert.hasText(resolvedApiKey, "Cohere API key must be set");
+ Assert.hasText(resoledBaseUrl, "Cohere base URL must be set");
+
+ return CohereApi.builder()
+ .baseUrl(resoledBaseUrl)
+ .apiKey(resolvedApiKey)
+ .restClientBuilder(restClientBuilder)
+ .webClientBuilder(webClientBuilder)
+ .responseErrorHandler(responseErrorHandler)
+ .build();
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java
new file mode 100644
index 00000000000..49ef712477f
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereChatProperties.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.chat.CohereChatOptions;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Configuration properties for Cohere chat.
+ *
+ * @author Ricken Bazolo
+ */
+@ConfigurationProperties(CohereChatProperties.CONFIG_PREFIX)
+public class CohereChatProperties extends CohereParentProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.cohere.chat";
+
+ public static final String DEFAULT_CHAT_MODEL = CohereApi.ChatModel.COMMAND_A_R7B.getValue();
+
+ private static final Double DEFAULT_TEMPERATURE = 0.3;
+
+ private static final Double DEFAULT_TOP_P = 1.0;
+
+ @NestedConfigurationProperty
+ private CohereChatOptions options = CohereChatOptions.builder()
+ .model(DEFAULT_CHAT_MODEL)
+ .temperature(DEFAULT_TEMPERATURE)
+ .topP(DEFAULT_TOP_P)
+ .presencePenalty(0.0)
+ .frequencyPenalty(0.0)
+ .logprobs(false)
+ .build();
+
+ public CohereChatProperties() {
+ super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL);
+ }
+
+ public CohereChatOptions getOptions() {
+ return this.options;
+ }
+
+ public void setOptions(CohereChatOptions options) {
+ this.options = options;
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java
new file mode 100644
index 00000000000..db672b8328b
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereCommonProperties.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * Common properties for Cohere.
+ *
+ * @author Ricken Bazolo
+ */
+@ConfigurationProperties(CohereCommonProperties.CONFIG_PREFIX)
+public class CohereCommonProperties extends CohereParentProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.cohere";
+
+ public static final String DEFAULT_BASE_URL = "https://api.cohere.com";
+
+ public CohereCommonProperties() {
+ super.setBaseUrl(DEFAULT_BASE_URL);
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java
new file mode 100644
index 00000000000..127610a2288
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingAutoConfiguration.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import io.micrometer.observation.ObservationRegistry;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingModel;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
+import org.springframework.ai.model.SpringAIModelProperties;
+import org.springframework.ai.model.SpringAIModels;
+import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+
+/**
+ * Embedding {@link AutoConfiguration Auto-configuration} for Cohere
+ *
+ * @author Ricken Bazolo
+ */
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
+@EnableConfigurationProperties({ CohereCommonProperties.class, CohereEmbeddingProperties.class })
+@ConditionalOnClass(CohereApi.class)
+@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.COHERE,
+ matchIfMissing = true)
+public class CohereEmbeddingAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean
+ public CohereEmbeddingModel mistralAiEmbeddingModel(CohereCommonProperties commonProperties,
+ CohereEmbeddingProperties embeddingProperties, ObjectProvider restClientBuilderProvider,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
+ ObjectProvider observationRegistry,
+ ObjectProvider observationConvention) {
+
+ var cohereApi = cohereApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(),
+ embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
+ restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler);
+
+ var embeddingModel = CohereEmbeddingModel.builder()
+ .cohereApi(cohereApi)
+ .metadataMode(embeddingProperties.getMetadataMode())
+ .options(embeddingProperties.getOptions())
+ .retryTemplate(retryTemplate)
+ .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
+ .build();
+
+ observationConvention.ifAvailable(embeddingModel::setObservationConvention);
+
+ return embeddingModel;
+ }
+
+ private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
+ RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
+
+ var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
+ var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
+
+ Assert.hasText(resolvedApiKey, "Cohere API key must be set");
+ Assert.hasText(resoledBaseUrl, "Cohere base URL must be set");
+
+ return CohereApi.builder()
+ .baseUrl(resoledBaseUrl)
+ .apiKey(resolvedApiKey)
+ .restClientBuilder(restClientBuilder)
+ .responseErrorHandler(responseErrorHandler)
+ .build();
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java
new file mode 100644
index 00000000000..75acd6b3f8c
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereEmbeddingProperties.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import java.util.List;
+
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingModel;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingType;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingOptions;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Configuration properties for Cohere embedding model.
+ *
+ * @author Ricken Bazolo
+ */
+@ConfigurationProperties(CohereEmbeddingProperties.CONFIG_PREFIX)
+public class CohereEmbeddingProperties extends CohereParentProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.cohere.embedding";
+
+ public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBED_V4.getValue();
+
+ public static final String DEFAULT_ENCODING_FORMAT = EmbeddingType.FLOAT.name();
+
+ public MetadataMode metadataMode = MetadataMode.EMBED;
+
+ @NestedConfigurationProperty
+ private final CohereEmbeddingOptions options = CohereEmbeddingOptions.builder()
+ .model(DEFAULT_EMBEDDING_MODEL)
+ .embeddingTypes(List.of(EmbeddingType.valueOf(DEFAULT_ENCODING_FORMAT)))
+ .build();
+
+ public CohereEmbeddingProperties() {
+ super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL);
+ }
+
+ public CohereEmbeddingOptions getOptions() {
+ return this.options;
+ }
+
+ public MetadataMode getMetadataMode() {
+ return this.metadataMode;
+ }
+
+ public void setMetadataMode(MetadataMode metadataMode) {
+ this.metadataMode = metadataMode;
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java
new file mode 100644
index 00000000000..72901f91c1e
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingAutoConfiguration.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import io.micrometer.observation.ObservationRegistry;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel;
+import org.springframework.ai.model.SpringAIModelProperties;
+import org.springframework.ai.model.SpringAIModels;
+import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+
+/**
+ * Multimodal Embedding {@link AutoConfiguration Auto-configuration} for Cohere
+ *
+ * @author Ricken Bazolo
+ */
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
+@EnableConfigurationProperties({ CohereCommonProperties.class, CohereMultimodalEmbeddingProperties.class })
+@ConditionalOnClass({ CohereApi.class, CohereMultimodalEmbeddingModel.class })
+@ConditionalOnProperty(name = SpringAIModelProperties.MULTI_MODAL_EMBEDDING_MODEL, havingValue = SpringAIModels.COHERE,
+ matchIfMissing = true)
+public class CohereMultimodalEmbeddingAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean
+ public CohereMultimodalEmbeddingModel cohereMultimodalEmbeddingModel(CohereCommonProperties commonProperties,
+ CohereMultimodalEmbeddingProperties embeddingProperties,
+ ObjectProvider restClientBuilderProvider, RetryTemplate retryTemplate,
+ ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry) {
+
+ var cohereApi = cohereApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(),
+ embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
+ restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler);
+
+ return CohereMultimodalEmbeddingModel.builder()
+ .cohereApi(cohereApi)
+ .options(embeddingProperties.getOptions())
+ .retryTemplate(retryTemplate)
+ .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
+ .build();
+ }
+
+ private CohereApi cohereApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
+ RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
+
+ var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
+ var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
+
+ Assert.hasText(resolvedApiKey, "Cohere API key must be set");
+ Assert.hasText(resoledBaseUrl, "Cohere base URL must be set");
+
+ return CohereApi.builder()
+ .baseUrl(resoledBaseUrl)
+ .apiKey(resolvedApiKey)
+ .restClientBuilder(restClientBuilder)
+ .responseErrorHandler(responseErrorHandler)
+ .build();
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java
new file mode 100644
index 00000000000..8caa70cdcc8
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereMultimodalEmbeddingProperties.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import java.util.List;
+
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingModel;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingType;
+import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingOptions;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Configuration properties for Cohere multimodal embedding model.
+ *
+ * @author Ricken Bazolo
+ */
+@ConfigurationProperties(CohereMultimodalEmbeddingProperties.CONFIG_PREFIX)
+public class CohereMultimodalEmbeddingProperties extends CohereParentProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.cohere.embedding.multimodal";
+
+ public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBED_V4.getValue();
+
+ public static final String DEFAULT_ENCODING_FORMAT = EmbeddingType.FLOAT.name();
+
+ @NestedConfigurationProperty
+ private final CohereMultimodalEmbeddingOptions options = CohereMultimodalEmbeddingOptions.builder()
+ .model(DEFAULT_EMBEDDING_MODEL)
+ .embeddingTypes(List.of(EmbeddingType.valueOf(DEFAULT_ENCODING_FORMAT)))
+ .build();
+
+ public CohereMultimodalEmbeddingProperties() {
+ super.setBaseUrl(CohereCommonProperties.DEFAULT_BASE_URL);
+ }
+
+ public CohereMultimodalEmbeddingOptions getOptions() {
+ return this.options;
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java
new file mode 100644
index 00000000000..0561690d7bb
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/java/org/springframework/ai/cohere/autoconfigure/CohereParentProperties.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+/**
+ * Parent properties for Cohere.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereParentProperties {
+
+ private String apiKey;
+
+ private String baseUrl;
+
+ public String getApiKey() {
+ return this.apiKey;
+ }
+
+ public void setApiKey(String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getBaseUrl() {
+ return this.baseUrl;
+ }
+
+ public void setBaseUrl(String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json
new file mode 100644
index 00000000000..5952657975c
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/additional-spring-configuration-metadata.json
@@ -0,0 +1,11 @@
+{
+ "groups": [
+ {
+ "name": "spring.ai.cohere.chat.options.tool-choice",
+ "type": "org.springframework.ai.cohere.api.CohereApi$ChatCompletionRequest$ToolChoice",
+ "sourceType": "org.springframework.ai.cohere.chat.CohereChatOptions"
+ }
+ ],
+ "properties": [],
+ "hints": []
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
new file mode 100644
index 00000000000..1643591155e
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -0,0 +1,18 @@
+#
+# Copyright 2025-2025 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+org.springframework.ai.cohere.autoconfigure.CohereChatAutoConfiguration
+org.springframework.ai.cohere.autoconfigure.CohereEmbeddingAutoConfiguration
+org.springframework.ai.cohere.autoconfigure.CohereMultimodalEmbeddingAutoConfiguration
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java
new file mode 100644
index 00000000000..d8ec9261f35
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereAutoConfigurationIT.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.cohere.chat.CohereChatModel;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingModel;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.DocumentEmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingOptions;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.utils.SpringAiTestAutoConfigurations;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Ricken Bazolo
+ */
+@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".*")
+public class CohereAutoConfigurationIT {
+
+ private static final Log logger = LogFactory.getLog(CohereAutoConfigurationIT.class);
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY"))
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class));
+
+ @Test
+ void generate() {
+ this.contextRunner.withConfiguration(AutoConfigurations.of(CohereChatAutoConfiguration.class)).run(context -> {
+ CohereChatModel chatModel = context.getBean(CohereChatModel.class);
+ String response = chatModel.call("Hello");
+ assertThat(response).isNotEmpty();
+ logger.info("Response: " + response);
+ });
+ }
+
+ @Test
+ void embedding() {
+ this.contextRunner.withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class))
+ .run(context -> {
+ CohereEmbeddingModel embeddingModel = context.getBean(CohereEmbeddingModel.class);
+
+ EmbeddingResponse embeddingResponse = embeddingModel
+ .embedForResponse(List.of("Hello World", "World is big and salvation is near"));
+ assertThat(embeddingResponse.getResults()).hasSize(2);
+ assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+ assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
+ assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
+ assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
+
+ assertThat(embeddingModel.dimensions()).isEqualTo(1536);
+ });
+ }
+
+ @Test
+ void generateStreaming() {
+ this.contextRunner.withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class))
+ .run(context -> {
+ CohereChatModel chatModel = context.getBean(CohereChatModel.class);
+ Flux responseFlux = chatModel
+ .stream(new org.springframework.ai.chat.prompt.Prompt(
+ new org.springframework.ai.chat.messages.UserMessage("Hello")));
+ String response = responseFlux.collectList()
+ .block()
+ .stream()
+ .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText())
+ .collect(java.util.stream.Collectors.joining());
+
+ assertThat(response).isNotEmpty();
+ logger.info("Response: " + response);
+ });
+ }
+
+ @Test
+ public void multimodalEmbedding() {
+ this.contextRunner
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereMultimodalEmbeddingAutoConfiguration.class))
+ .run(context -> {
+ var multimodalEmbeddingProperties = context.getBean(CohereMultimodalEmbeddingProperties.class);
+
+ assertThat(multimodalEmbeddingProperties).isNotNull();
+
+ var multiModelEmbeddingModel = context
+ .getBean(org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel.class);
+
+ assertThat(multiModelEmbeddingModel).isNotNull();
+
+ var document = new Document("Hello World");
+
+ DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document),
+ EmbeddingOptions.builder().build());
+
+ EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest);
+ assertThat(embeddingResponse.getResults()).hasSize(1);
+ assertThat(embeddingResponse.getResults().get(0)).isNotNull();
+ assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536);
+
+ assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1536);
+
+ });
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java
new file mode 100644
index 00000000000..79ad7a6e21f
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CohereModelConfigurationTests.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.cohere.chat.CohereChatModel;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingModel;
+import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel;
+import org.springframework.ai.utils.SpringAiTestAutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit Tests for Cohere auto-configurations conditional enabling of models.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereModelConfigurationTests {
+
+ private final ApplicationContextRunner chatContextRunner = new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY"))
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereChatAutoConfiguration.class));
+
+ private final ApplicationContextRunner embeddingContextRunner = new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY"))
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class));
+
+ private final ApplicationContextRunner embeddingMultimodalContextRunner = new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.apiKey=" + System.getenv("COHERE_API_KEY"))
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereMultimodalEmbeddingAutoConfiguration.class));
+
+ @Test
+ void chatModelActivation() {
+ this.chatContextRunner.run(context -> {
+ assertThat(context.getBeansOfType(CohereChatProperties.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereChatModel.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty();
+ });
+
+ this.chatContextRunner.withPropertyValues("spring.ai.model.chat=none", "spring.ai.model.embedding=none")
+ .run(context -> {
+ assertThat(context.getBeansOfType(CohereChatProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(CohereChatModel.class)).isEmpty();
+ });
+
+ this.chatContextRunner.withPropertyValues("spring.ai.model.chat=cohere", "spring.ai.model.embedding=none")
+ .run(context -> {
+ assertThat(context.getBeansOfType(CohereChatProperties.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereChatModel.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty();
+ });
+ }
+
+ @Test
+ void embeddingModelActivation() {
+ this.embeddingContextRunner
+ .run(context -> assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isNotEmpty());
+
+ this.embeddingContextRunner.withPropertyValues("spring.ai.model.embedding=none").run(context -> {
+ assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isEmpty();
+ });
+
+ this.embeddingContextRunner.withPropertyValues("spring.ai.model.embedding=cohere").run(context -> {
+ assertThat(context.getBeansOfType(CohereEmbeddingProperties.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereEmbeddingModel.class)).isNotEmpty();
+ });
+ }
+
+ @Test
+ void multimodalEmbeddingActivation() {
+ this.embeddingMultimodalContextRunner
+ .run(context -> assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isNotEmpty());
+
+ this.embeddingMultimodalContextRunner.withPropertyValues("spring.ai.model.embedding.multimodal=none")
+ .run(context -> {
+ assertThat(context.getBeansOfType(CohereMultimodalEmbeddingProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isEmpty();
+ });
+
+ this.embeddingMultimodalContextRunner.withPropertyValues("spring.ai.model.embedding.multimodal=cohere")
+ .run(context -> {
+ assertThat(context.getBeansOfType(CohereMultimodalEmbeddingProperties.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(CohereMultimodalEmbeddingModel.class)).isNotEmpty();
+ });
+ }
+
+}
diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java
new file mode 100644
index 00000000000..ff7f6fd88b4
--- /dev/null
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-cohere/src/test/java/org/springframework/ai/cohere/autoconfigure/CoherePropertiesTests.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.autoconfigure;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.utils.SpringAiTestAutoConfigurations;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit Tests for {@link CohereCommonProperties}.
+ */
+public class CoherePropertiesTests {
+
+ @Test
+ public void chatOptionsTest() {
+
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123",
+ "spring.ai.cohere.chat.options.tools[0].function.name=myFunction1",
+ "spring.ai.cohere.chat.options.tools[0].function.description=function description",
+ "spring.ai.cohere.chat.options.tools[0].function.jsonSchema=" + """
+ {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state e.g. San Francisco, CA"
+ },
+ "lat": {
+ "type": "number",
+ "description": "The city latitude"
+ },
+ "lon": {
+ "type": "number",
+ "description": "The city longitude"
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["c", "f"]
+ }
+ },
+ "required": ["location", "lat", "lon", "unit"]
+ }
+ """)
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, CohereChatAutoConfiguration.class))
+ .run(context -> {
+
+ var chatProperties = context.getBean(CohereChatProperties.class);
+
+ var tool = chatProperties.getOptions().getTools().get(0);
+ assertThat(tool.getType()).isEqualTo(CohereApi.FunctionTool.Type.FUNCTION);
+ var function = tool.getFunction();
+ assertThat(function.getName()).isEqualTo("myFunction1");
+ assertThat(function.getDescription()).isEqualTo("function description");
+ assertThat(function.getParameters()).isNotEmpty();
+ });
+ }
+
+ @Test
+ public void embeddingProperties() {
+
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123",
+ "spring.ai.cohere.embedding.options.model=MODEL_XYZ")
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class))
+ .run(context -> {
+ var embeddingProperties = context.getBean(CohereEmbeddingProperties.class);
+ var connectionProperties = context.getBean(CohereCommonProperties.class);
+
+ assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
+ assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+
+ assertThat(embeddingProperties.getApiKey()).isNull();
+ assertThat(embeddingProperties.getBaseUrl()).isEqualTo(CohereCommonProperties.DEFAULT_BASE_URL);
+
+ assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ });
+ }
+
+ @Test
+ public void embeddingOverrideConnectionProperties() {
+
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.base-url=TEST_BASE_URL", "spring.ai.cohere.api-key=abc123",
+ "spring.ai.cohere.embedding.base-url=TEST_BASE_URL2", "spring.ai.cohere.embedding.api-key=456",
+ "spring.ai.cohere.embedding.options.model=MODEL_XYZ")
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class))
+ .run(context -> {
+ var embeddingProperties = context.getBean(CohereEmbeddingProperties.class);
+ var connectionProperties = context.getBean(CohereCommonProperties.class);
+
+ assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
+ assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+
+ assertThat(embeddingProperties.getApiKey()).isEqualTo("456");
+ assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
+
+ assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ });
+ }
+
+ @Test
+ public void embeddingOptionsTest() {
+
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.cohere.api-key=API_KEY", "spring.ai.cohere.base-url=TEST_BASE_URL",
+ "spring.ai.cohere.embedding.options.model=MODEL_XYZ",
+ "spring.ai.cohere.embedding.options.embedding-types[0]=FLOAT",
+ "spring.ai.cohere.embedding.options.input-type=search_document",
+ "spring.ai.cohere.embedding.options.truncate=END")
+ .withConfiguration(SpringAiTestAutoConfigurations.of(CohereEmbeddingAutoConfiguration.class))
+ .run(context -> {
+ var connectionProperties = context.getBean(CohereCommonProperties.class);
+ var embeddingProperties = context.getBean(CohereEmbeddingProperties.class);
+
+ assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+ assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
+
+ assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ assertThat(embeddingProperties.getOptions().getEmbeddingTypes().get(0).name()).isEqualTo("FLOAT");
+ assertThat(embeddingProperties.getOptions().getTruncate().name()).isEqualTo("END");
+ assertThat(embeddingProperties.getOptions().getInputType().name()).isEqualTo("SEARCH_DOCUMENT");
+ });
+ }
+
+}
diff --git a/models/spring-ai-cohere/README.md b/models/spring-ai-cohere/README.md
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/models/spring-ai-cohere/pom.xml b/models/spring-ai-cohere/pom.xml
new file mode 100644
index 00000000000..4b4d97dc4c8
--- /dev/null
+++ b/models/spring-ai-cohere/pom.xml
@@ -0,0 +1,90 @@
+
+
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai-parent
+ 2.0.0-SNAPSHOT
+ ../../pom.xml
+
+ spring-ai-cohere
+ jar
+ Spring AI Model - Cohere
+ Cohere models support
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+
+
+
+
+ org.springframework.ai
+ spring-ai-model
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
+
+
+
+ org.springframework
+ spring-context-support
+
+
+
+ org.springframework
+ spring-webflux
+
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+ org.springframework.ai
+ spring-ai-test
+ ${project.version}
+ test
+
+
+
+ io.micrometer
+ micrometer-observation-test
+ test
+
+
+
+
+
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java
new file mode 100644
index 00000000000..b25b5a4843d
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/aot/CohereRuntimeHints.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.aot;
+
+import org.springframework.aot.hint.MemberCategory;
+import org.springframework.aot.hint.RuntimeHints;
+import org.springframework.aot.hint.RuntimeHintsRegistrar;
+
+import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
+
+/**
+ * The CohereRuntimeHints class is responsible for registering runtime hints for Cohere AI
+ * API classes.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereRuntimeHints implements RuntimeHintsRegistrar {
+
+ @Override
+ public void registerHints(final RuntimeHints hints, final ClassLoader classLoader) {
+ var mcs = MemberCategory.values();
+
+ for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere")) {
+ hints.reflection().registerType(tr, mcs);
+ }
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java
new file mode 100644
index 00000000000..f79808832da
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereApi.java
@@ -0,0 +1,1332 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+
+import com.fasterxml.jackson.annotation.JsonFormat;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import org.springframework.ai.model.ChatModelDescription;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
+
+/**
+ * Java Client library for Cohere Platform. Provides implementation for the
+ * Chat and
+ * Chat Stream
+ * Embedding API.
+ *
+ * Implements Synchronous and Streaming chat completion and supports latest
+ * Function Calling features.
+ *
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereApi {
+
+ public static final String PROVIDER_NAME = AiProvider.COHERE.value();
+
+ private static final String DEFAULT_BASE_URL = "https://api.cohere.com";
+
+ private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals;
+
+ private final RestClient restClient;
+
+ private final WebClient webClient;
+
+ private final CohereStreamFunctionCallingHelper chunkMerger = new CohereStreamFunctionCallingHelper();
+
+ /**
+ * Create a new client api with DEFAULT_BASE_URL
+ * @param cohereApiKey Cohere api Key.
+ */
+ public CohereApi(String cohereApiKey) {
+ this(DEFAULT_BASE_URL, cohereApiKey);
+ }
+
+ /**
+ * Create a new client api.
+ * @param baseUrl api base URL.
+ * @param cohereApiKey Cohere api Key.
+ */
+ public CohereApi(String baseUrl, String cohereApiKey) {
+ this(baseUrl, cohereApiKey, RestClient.builder(), WebClient.builder(),
+ RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
+ }
+
+ /**
+ * Create a new client api.
+ * @param baseUrl api base URL.
+ * @param cohereApiKey Cohere api Key.
+ * @param restClientBuilder RestClient builder.
+ * @param responseErrorHandler Response error handler.
+ */
+ public CohereApi(String baseUrl, String cohereApiKey, RestClient.Builder restClientBuilder,
+ WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
+
+ Consumer jsonContentHeaders = headers -> {
+ headers.setBearerAuth(cohereApiKey);
+ headers.setContentType(MediaType.APPLICATION_JSON);
+ };
+
+ this.restClient = restClientBuilder.baseUrl(baseUrl)
+ .defaultHeaders(jsonContentHeaders)
+ .defaultStatusHandler(responseErrorHandler)
+ .build();
+
+ this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
+ }
+
+ /**
+ * Creates a model response for the given chat conversation.
+ * @param chatRequest The chat completion request.
+ * @return Entity response with {@link ChatCompletion} as a body and HTTP status code
+ * and headers.
+ */
+ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) {
+
+ Assert.notNull(chatRequest, "The request body can not be null.");
+ Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
+
+ return this.restClient.post().uri("/v2/chat/").body(chatRequest).retrieve().toEntity(ChatCompletion.class);
+ }
+
+ /**
+ * Creates an embedding vector representing the input text, token array, or images.
+ * @param embeddingRequest The embedding request.
+ * @return Returns {@link EmbeddingResponse} with embeddings data.
+ * @param Type of the entity in the data list. Can be a {@link String} or
+ * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single
+ * request, You can pass a {@link List} of {@link String} or {@link List} of
+ * {@link List} of tokens. For example:
+ *
+ *
{@code List.of("text1", "text2", "text3")}
+ */
+ public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) {
+
+ Assert.notNull(embeddingRequest, "The request body can not be null.");
+
+ boolean hasTexts = !CollectionUtils.isEmpty(embeddingRequest.texts);
+ boolean hasImages = !CollectionUtils.isEmpty(embeddingRequest.images);
+
+ Assert.isTrue(hasTexts || hasImages, "Either texts or images must be provided");
+ Assert.isTrue(!(hasTexts && hasImages), "Cannot provide both texts and images in the same request");
+
+ if (hasTexts) {
+ Assert.isTrue(embeddingRequest.texts.size() <= 96, "The texts list must be 96 items or less");
+ }
+
+ if (hasImages) {
+ Assert.isTrue(embeddingRequest.images.size() <= 1, "Only one image per request is supported");
+ }
+
+ return this.restClient.post()
+ .uri("/v2/embed")
+ .body(embeddingRequest)
+ .retrieve()
+ .toEntity(new ParameterizedTypeReference<>() {
+
+ });
+ }
+
+ /**
+ * Creates a streaming chat response for the given chat conversation.
+ * @param chatRequest The chat completion request. Must have the stream property set
+ * to true.
+ * @return Returns a {@link Flux} stream from chat completion chunks.
+ */
+ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) {
+
+ Assert.notNull(chatRequest, "The request body can not be null.");
+ Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");
+
+ return this.webClient.post()
+ .uri("v2/chat")
+ .body(Mono.just(chatRequest), ChatCompletionRequest.class)
+ .retrieve()
+ .bodyToFlux(String.class)
+ .takeUntil(SSE_DONE_PREDICATE)
+ .filter(SSE_DONE_PREDICATE.negate())
+ .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
+ .groupBy(chunk -> chunk.id() != null ? chunk.id() : "no-id")
+ .flatMap(group -> group.reduce(new ChatCompletionChunk(null, null, null, null), this.chunkMerger::merge)
+ .filter(chunk -> EventType.MESSAGE_END.value.equals(chunk.type())
+ || (chunk.delta() != null && chunk.delta().finishReason() != null)))
+ .map(this.chunkMerger::sanitizeToolCalls)
+ .filter(this.chunkMerger::hasValidToolCallsOnly)
+ .filter(Objects::nonNull);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Builder for creating CohereApi instances.
+ */
+ public static class Builder {
+
+ private String baseUrl = DEFAULT_BASE_URL;
+
+ private String apiKey;
+
+ private RestClient.Builder restClientBuilder = RestClient.builder();
+
+ private WebClient.Builder webClientBuilder = WebClient.builder();
+
+ private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
+
+ public Builder baseUrl(String baseUrl) {
+ this.baseUrl = baseUrl;
+ return this;
+ }
+
+ public Builder apiKey(String apiKey) {
+ this.apiKey = apiKey;
+ return this;
+ }
+
+ public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
+ this.restClientBuilder = restClientBuilder;
+ return this;
+ }
+
+ public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
+ this.webClientBuilder = webClientBuilder;
+ return this;
+ }
+
+ public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
+ this.responseErrorHandler = responseErrorHandler;
+ return this;
+ }
+
+ public CohereApi build() {
+ Assert.hasText(this.apiKey, "Cohere API key must be set");
+ Assert.hasText(this.baseUrl, "Cohere base URL must be set");
+ Assert.notNull(this.restClientBuilder, "RestClient.Builder must not be null");
+ Assert.notNull(this.webClientBuilder, "WebClient.Builder must not be null");
+ Assert.notNull(this.responseErrorHandler, "ResponseErrorHandler must not be null");
+
+ return new CohereApi(this.baseUrl, this.apiKey, this.restClientBuilder, this.webClientBuilder,
+ this.responseErrorHandler);
+ }
+
+ }
+
+ /**
+ * List of well-known Cohere chat models.
+ *
+ * @see Cohere Models Overview
+ */
+ public enum ChatModel implements ChatModelDescription {
+
+ COMMAND_A("command-a-03-2025"),
+
+ COMMAND_A_REASONING("command-a-reasoning-08-2025"),
+
+ COMMAND_A_TRANSLATE("command-a-translate-08-2025"),
+
+ COMMAND_A_VISION("command-a-vision-07-2025"),
+
+ COMMAND_A_R7B("command-r7b-12-2024"),
+
+ COMMAND_R_PLUS("command-r-plus-08-2024"),
+
+ COMMAND_R("command-r-08-2024");
+
+ private final String value;
+
+ ChatModel(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return this.value;
+ }
+
+ @Override
+ public String getName() {
+ return this.value;
+ }
+
+ }
+
+ /**
+ * Usage statistics.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record Usage(@JsonProperty("billedUnits") BilledUnits billedUnits, @JsonProperty("tokens") Tokens tokens,
+ @JsonProperty("cached_tokens") Integer cachedTokens) {
+ /**
+ * Bille units
+ *
+ * @param inputTokens The number of billed input tokens.
+ * @param outputTokens The number of billed output tokens.
+ * @param searchUnits The number of billed search units.
+ * @param classifications The number of billed classifications units.
+ */
+ public record BilledUnits(@JsonProperty("input_tokens") Integer inputTokens,
+ @JsonProperty("output_tokens") Integer outputTokens, @JsonProperty("search_units") Double searchUnits,
+ @JsonProperty("classifications") Double classifications) {
+ }
+
+ /**
+ * The Tokens
+ *
+ * @param inputTokens The number of tokens used as input to the model.
+ * @param outputTokens The number of tokens produced by the model.
+ */
+ public record Tokens(@JsonProperty("input_tokens") Integer inputTokens,
+ @JsonProperty("output_tokens") Integer outputTokens) {
+ }
+ }
+
+ /**
+ * Creates a model request for chat conversation.
+ *
+ * @param model The name of a compatible Cohere model or the ID of a fine-tuned model.
+ * @param messages The prompt(s) to generate completions for, encoded as a list of
+ * dict with role and rawContent. The first prompt role should be user or system.
+ * @param tools A list of tools the model may call. Currently, only functions are
+ * supported as a tool. Use this to provide a list of functions the model may generate
+ * JSON inputs for.
+ * @param documents A list of relevant documents that the model can cite to generate a
+ * more accurate reply. Each document is either a string or document object with
+ * rawContent and metadata.
+ * @param citationOptions Options for controlling citation generation.
+ * @param responseFormat An object specifying the format or schema that the model must
+ * output. Setting to { "type": "json_object" } enables JSON mode, which guarantees
+ * the message the model generates is valid JSON. Setting to { "type": "json_object" ,
+ * "json_schema": schema} allows you to ensure the model provides an answer in a very
+ * specific JSON format by supplying a clear JSON schema.
+ * @param safetyMode Safety modes are not yet configurable in combination with tools,
+ * tool_results and documents parameters.
+ * @param maxTokens The maximum number of tokens to generate in the completion. The
+ * token count of your prompt plus max_tokens cannot exceed the model's context
+ * length.
+ * @param stopSequences A list of tokens that the model should stop generating after.
+ * If set,
+ * @param temperature What sampling temperature to use, between 0.0 and 1.0. Higher
+ * values like 0.8 will make the output more random, while lower values like 0.2 will
+ * make it more focused and deterministic. We generally recommend altering this or p
+ * but not both.
+ * @param seed If specified, the backend will make a best effort to sample tokens
+ * deterministically, such that repeated requests with the same seed and parameters
+ * should return the same result. However, determinism cannot be totally guaranteed.
+ * @param frequencyPenalty Number between 0.0 and 1.0. Used to reduce repetitiveness
+ * of generated tokens. The higher the value, the stronger a penalty is applied to
+ * previously present tokens, proportional to how many times they have already
+ * appeared in the prompt or prior generation.
+ * @param presencePenalty min value of 0.0, max value of 1.0. Used to reduce
+ * repetitiveness of generated tokens. Similar to frequency_penalty, except that this
+ * penalty is applied equally to all tokens that have already appeared, regardless of
+ * their exact frequencies.
+ * @param stream When true, the response will be a SSE stream of events. The final
+ * event will contain the complete response, and will have an event_type of
+ * "stream-end".
+ * @param k Ensures that only the top k most likely tokens are considered for
+ * generation at each step. When k is set to 0, k-sampling is disabled. Defaults to 0,
+ * min value of 0, max value of 500.
+ * @param p Ensures that only the most likely tokens, with total probability mass of
+ * p, are considered for generation at each step. If both k and p are enabled, p acts
+ * after k. Defaults to 0.75. min value of 0.01, max value of 0.99.
+ * @param logprobs Defaults to false. When set to true, the log probabilities of the
+ * generated tokens will be included in the response.
+ * @param toolChoice Used to control whether or not the model will be forced to use a
+ * tool when answering. When REQUIRED is specified, the model will be forced to use at
+ * least one of the user-defined tools, and the tools parameter must be passed in the
+ * request. When NONE is specified, the model will be forced not to use one of the
+ * specified tools, and give a direct response. If tool_choice isn’t specified, then
+ * the model is free to choose whether to use the specified tools or not.
+ * @param strictTools When set to true, tool calls in the Assistant message will be
+ * forced to follow the tool definition strictly.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ChatCompletionRequest(@JsonProperty("model") String model,
+ @JsonProperty("messages") List messages,
+ @JsonProperty("tools") List tools, @JsonProperty("documents") List documents,
+ @JsonProperty("citation_options") CitationOptions citationOptions,
+ @JsonProperty("response_format") ResponseFormat responseFormat,
+ @JsonProperty("safety_mode") SafetyMode safetyMode, @JsonProperty("max_tokens") Integer maxTokens,
+ @JsonProperty("stop_sequences") List stopSequences, @JsonProperty("temperature") Double temperature,
+ @JsonProperty("seed") Integer seed, @JsonProperty("frequency_penalty") Double frequencyPenalty,
+ @JsonProperty("stream") Boolean stream, @JsonProperty("k") Integer k, @JsonProperty("p") Double p,
+ @JsonProperty("logprobs") Boolean logprobs, @JsonProperty("tool_choice") ToolChoice toolChoice,
+ @JsonProperty("strict_tools") Boolean strictTools,
+ @JsonProperty("presence_penalty") Double presencePenalty) {
+
+ /**
+ * Shortcut constructor for a chat completion request with the given messages and
+ * model.
+ * @param messages The prompt(s) to generate completions for, encoded as a list of
+ * dict with role and rawContent. The first prompt role should be user or system.
+ * @param model ID or name of the model to use.
+ */
+ public ChatCompletionRequest(List messages, String model) {
+ this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null,
+ null, 0.3, null, null, false, 0, 0.75, false, null, false, null);
+ }
+
+ /**
+ * Shortcut constructor for a chat completion request with the given messages,
+ * model and temperature.
+ * @param messages The prompt(s) to generate completions for, encoded as a list of
+ * dict with role and rawContent. The first prompt role should be user or system.
+ * @param model ID or model of the model to use.
+ * @param temperature What sampling temperature to use, between 0.0 and 1.0.
+ * @param stream Whether to stream back partial progress. If set, tokens will be
+ * sent
+ */
+ public ChatCompletionRequest(List messages, String model, Double temperature,
+ boolean stream) {
+ this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null,
+ null, temperature, null, null, stream, 0, 0.75, false, null, false, null);
+ }
+
+ /**
+ * Shortcut constructor for a chat completion request with the given messages,
+ * model and temperature.
+ * @param messages The prompt(s) to generate completions for, encoded as a list of
+ * dict with role and rawContent. The first prompt role should be user or system.
+ * @param model ID of the model to use.
+ * @param temperature What sampling temperature to use, between 0.0 and 1.0.
+ *
+ */
+ public ChatCompletionRequest(List messages, String model, Double temperature) {
+ this(model, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null,
+ null, temperature, null, null, false, 0, 0.75, false, null, false, null);
+ }
+
+ /**
+ * Shortcut constructor for a chat completion request with the given messages,
+ * model, tools and tool choice. Streaming is set to false, temperature to 0.8 and
+ * all other parameters are null.
+ * @param messages A list of messages comprising the conversation so far.
+ * @param model ID of the model to use.
+ * @param tools A list of tools the model may call. Currently, only functions are
+ * supported as a tool.
+ * @param toolChoice Controls which (if any) function is called by the model.
+ */
+ public ChatCompletionRequest(List messages, String model, List tools,
+ ToolChoice toolChoice) {
+ this(model, messages, tools, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL,
+ null, null, 0.75, null, null, false, 0, 0.75, false, toolChoice, false, null);
+ }
+
+ /**
+ * Shortcut constructor for a chat completion request with the given messages and
+ * stream.
+ */
+ public ChatCompletionRequest(List messages, Boolean stream) {
+ this(null, messages, null, null, new CitationOptions(CitationMode.FAST), null, SafetyMode.CONTEXTUAL, null,
+ null, 0.75, null, null, stream, 0, 0.75, false, null, false, null);
+ }
+
+ /**
+ * An object specifying the format that the model must output.
+ *
+ * @param type Must be one of 'text' or 'json_object'.
+ * @param jsonSchema A specific JSON schema to match, if 'type' is 'json_object'.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ResponseFormat(@JsonProperty("type") String type,
+ @JsonProperty("json_schema") Map jsonSchema) {
+ }
+
+ /**
+ * Specifies a tool the model should use
+ */
+ public enum ToolChoice {
+
+ REQUIRED, NONE
+
+ }
+
+ }
+
+ /**
+ * Message comprising the conversation. A message from the assistant role can contain
+ * text and tool call information.
+ *
+ * @param role The role of the messages author. Could be one of the {@link Role} types
+ * "assistant".
+ * @param toolCalls The tool calls generated by the model, such as function calls.
+ * Applicable only for {@link Role#ASSISTANT} role and null otherwise.
+ * @param toolPlan A chain-of-thought style reflection and plan that the model
+ * generates when working with Tools.
+ * @param rawContent The contents of the message. Can be either a {@link MediaContent}
+ * or a {@link MessageContent}.
+ * @param citations Tool call that this message is responding to. Only applicable for
+ * the {@link ChatCompletionFinishReason#TOOL_CALL} role and null otherwise.
+ */
+ public record ChatCompletionMessage(@JsonProperty("content") Object rawContent, @JsonProperty("role") Role role,
+ @JsonProperty("tool_plan") String toolPlan,
+ @JsonFormat(
+ with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @JsonProperty("tool_calls") List toolCalls,
+ @JsonFormat(
+ with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @JsonProperty("citations") List citations,
+ @JsonProperty("tool_call_id") String toolCallId) {
+
+ public ChatCompletionMessage(Object content, Role role) {
+ this(content, role, null, null, null, null);
+ }
+
+ public ChatCompletionMessage(Object content, Role role, List toolCalls) {
+ this(content, role, null, toolCalls, null, null);
+ }
+
+ public ChatCompletionMessage(Object content, Role role, List toolCalls, String toolPlan) {
+ this(content, role, toolPlan, toolCalls, null, null);
+ }
+
+ public ChatCompletionMessage(Object content, Role role, String toolCallId) {
+ this(content, role, null, null, null, toolCallId);
+ }
+
+ /**
+ * Get message content as String.
+ */
+ public String content() {
+ if (this.rawContent == null) {
+ return null;
+ }
+ if (this.rawContent instanceof String text) {
+ return text;
+ }
+ throw new IllegalStateException("The content is not a string!");
+ }
+
+ /**
+ * An array of rawContent parts with a defined type. Each MediaContent can be of
+ * either "text" or "image_url" type. Only one option allowed.
+ *
+ * @param type Content type, each can be of type text or image_url.
+ * @param text The text rawContent of the message.
+ * @param imageUrl The image rawContent of the message.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record MediaContent(@JsonProperty("type") String type, @JsonProperty("text") String text,
+ @JsonProperty("image_url") ImageUrl imageUrl) {
+
+ /**
+ * Shortcut constructor for a text rawContent.
+ * @param text The text rawContent of the message.
+ */
+ public MediaContent(String text) {
+ this("text", text, null);
+ }
+
+ /**
+ * Shortcut constructor for an image rawContent.
+ * @param imageUrl The image rawContent of the message.
+ */
+ public MediaContent(ImageUrl imageUrl) {
+ this("image_url", null, imageUrl);
+ }
+
+ /**
+ * The level of detail for processing the image.
+ */
+ public enum DetailLevel {
+
+ @JsonProperty("low")
+ LOW,
+
+ @JsonProperty("high")
+ HIGH,
+
+ @JsonProperty("auto")
+ AUTO
+
+ }
+
+ /**
+ * Shortcut constructor for an image rawContent.
+ *
+ * @param url Either a URL of the image or the base64 encoded image data. The
+ * base64 encoded image data must have a special prefix in the following
+ * format: "data:{mimetype};base64,{base64-encoded-image-data}".
+ * @param detail The level of detail for processing the image. Can be "low",
+ * "high", or "auto". Defaults to "auto" if not specified.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") DetailLevel detail) {
+
+ public ImageUrl(String url) {
+ this(url, DetailLevel.AUTO);
+ }
+
+ }
+ }
+
+ /**
+ * Message rawContent that can be either a text or a value.
+ *
+ * @param type The type of the message rawContent, such as "text" or "thinking".
+ * @param text The text rawContent of the message.
+ * @param value The value of the thinking, which can be any object.
+ */
+ public record MessageContent(@JsonProperty("type") String type, @JsonProperty("text") String text,
+ @JsonProperty("value") Object value) {
+ }
+
+ /**
+ * The role of the author of this message.
+ */
+ public enum Role {
+
+ /**
+ * User message.
+ */
+ @JsonProperty("user")
+ USER,
+ /**
+ * Assistant message.
+ */
+ @JsonProperty("assistant")
+ ASSISTANT,
+ /**
+ * System message.
+ */
+ @JsonProperty("system")
+ SYSTEM,
+ /**
+ * Tool message.
+ */
+ @JsonProperty("tool")
+ TOOL
+
+ }
+
+ /**
+ * The relevant tool call.
+ *
+ * @param id The ID of the tool call. This ID must be referenced when you submit
+ * the tool outputs in using the Submit tool outputs to run endpoint.
+ * @param type The type of tool call the output is required for. For now, this is
+ * always function.
+ * @param function The function definition.
+ * @param index The index of the tool call in the list of tool calls.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type,
+ @JsonProperty("function") ChatCompletionFunction function, @JsonProperty("index") Integer index) {
+ }
+
+ /**
+ * The function definition.
+ *
+ * @param name The name of the function.
+ * @param arguments The arguments that the model expects you to pass to the
+ * function.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ChatCompletionFunction(@JsonProperty("name") String name,
+ @JsonProperty("arguments") String arguments) {
+ }
+
+ public record ChatCompletionCitation(
+ /**
+ * Start index of the cited snippet in the original source text.
+ */
+ @JsonProperty("start") Integer start,
+ /**
+ * End index of the cited snippet in the original source text.
+ */
+ @JsonProperty("end") Integer end,
+ /**
+ * Text snippet that is being cited.
+ */
+ @JsonProperty("text") String text, @JsonProperty("sources") List sources,
+ @JsonProperty("type") Type type) {
+ /**
+ * The type of citation which indicates what part of the response the citation
+ * is for.
+ */
+ public enum Type {
+
+ TEXT_CONTENT, PLAN
+
+ }
+
+ /**
+ * @param type Tool or A document source object containing the unique
+ * identifier of the document and the document itself.
+ * @param id The unique identifier of the document
+ * @param toolOutput map from strings to any Optional if type == tool
+ * @param document map from strings to any Optional if type == document
+ */
+ public record Source(@JsonProperty("type") String type, @JsonProperty("id") String id,
+ @JsonProperty("tool_output") Map toolOutput,
+ @JsonProperty("document") Map document) {
+ }
+ }
+
+ public record Provider(@JsonProperty("content") List content, @JsonProperty("role") Role role,
+ @JsonProperty("tool_plan") String toolPlan, @JsonProperty("tool_calls") List toolCalls,
+ @JsonProperty("citations") List citations) {
+ }
+ }
+
+ /**
+ * Used to select the safety instruction inserted into the prompt. Defaults to
+ * CONTEXTUAL. When OFF is specified, the safety instruction will be omitted. Safety
+ * modes are not yet configurable in combination with tools, tool_results and
+ * documents parameters. Note: This parameter is only compatible newer Cohere models,
+ * starting with Command R 08-2024 and Command R+ 08-2024. Note: command-r7b-12-2024
+ * and newer models only support "CONTEXTUAL" and "STRICT" modes.
+ */
+ public enum SafetyMode {
+
+ CONTEXTUAL, STRICT, OFF
+
+ }
+
+ /**
+ * Options for controlling citation generation. Defaults to "accurate". Dictates the
+ * approach taken to generating citations as part of the RAG flow by allowing the user
+ * to specify whether they want "accurate" results, "fast" results or no results.
+ * Note: command-r7b-12-2024 and command-a-03-2025 only support "fast" and "off"
+ * modes. The default is "fast".
+ */
+ public record CitationOptions(@JsonProperty("mode") CitationMode mode) {
+ }
+
+ /**
+ * Options for controlling citation generation. Defaults to "accurate". Dictates the
+ * approach taken to generating citations as part of the RAG flow by allowing the user
+ * to specify whether they want "accurate" results, "fast" results or no results.
+ * Note: command-r7b-12-2024 and command-a-03-2025 only support "fast" and "off"
+ * modes. The default is "fast".
+ */
+ public enum CitationMode {
+
+ FAST, ACCURATE, OFF
+
+ }
+
+ /**
+ * relevant documents that the model can cite to generate a more accurate reply. Each
+ * document is either a string or document object with rawContent and metadata.
+ *
+ * @param id An optional Unique identifier for this document which will be referenced
+ * in citations. If not provided an ID will be automatically generated.
+ * @param data A relevant document that the model can cite to generate a more accurate
+ * reply. Each document is a string-any dictionary.
+ */
+ public record Document(@JsonProperty("id") String id, @JsonProperty("data") String data) {
+ }
+
+ /**
+ * Represents a tool the model may call. Currently, only functions are supported as a
+ * tool.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static class FunctionTool {
+
+ // The type of the tool. Currently, only 'function' is supported.
+ @JsonProperty("type")
+ Type type = Type.FUNCTION;
+
+ // The function definition.
+ @JsonProperty("function")
+ Function function;
+
+ public FunctionTool() {
+
+ }
+
+ /**
+ * Create a tool of type 'function' and the given function definition.
+ * @param function function definition.
+ */
+ public FunctionTool(Function function) {
+ this(Type.FUNCTION, function);
+ }
+
+ public FunctionTool(Type type, Function function) {
+ this.type = type;
+ this.function = function;
+ }
+
+ public Type getType() {
+ return this.type;
+ }
+
+ public Function getFunction() {
+ return this.function;
+ }
+
+ public void setType(Type type) {
+ this.type = type;
+ }
+
+ public void setFunction(Function function) {
+ this.function = function;
+ }
+
+ /**
+ * Create a tool of type 'function' and the given function definition.
+ */
+ public enum Type {
+
+ /**
+ * Function tool type.
+ */
+ @JsonProperty("function")
+ FUNCTION
+
+ }
+
+ /**
+ * Function definition.
+ */
+ public static class Function {
+
+ @JsonProperty("description")
+ private String description;
+
+ @JsonProperty("name")
+ private String name;
+
+ @JsonProperty("parameters")
+ private Map parameters;
+
+ @JsonIgnore
+ private String jsonSchema;
+
+ private Function() {
+
+ }
+
+ /**
+ * Create tool function definition.
+ * @param description A description of what the function does, used by the
+ * model to choose when and how to call the function.
+ * @param name The name of the function to be called. Must be a-z, A-Z, 0-9,
+ * or contain underscores and dashes, with a maximum length of 64.
+ * @param parameters The parameters the functions accepts, described as a JSON
+ * Schema object. To describe a function that accepts no parameters, provide
+ * the value {"type": "object", "properties": {}}.
+ */
+ public Function(String description, String name, Map parameters) {
+ this.description = description;
+ this.name = name;
+ this.parameters = parameters;
+ }
+
+ /**
+ * Create tool function definition.
+ * @param description tool function description.
+ * @param name tool function name.
+ * @param jsonSchema tool function schema as json.
+ */
+ public Function(String description, String name, String jsonSchema) {
+ this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema));
+ }
+
+ public String getDescription() {
+ return this.description;
+ }
+
+ public String getName() {
+ return this.name;
+ }
+
+ public Map getParameters() {
+ return this.parameters;
+ }
+
+ public void setDescription(String description) {
+ this.description = description;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public void setParameters(Map parameters) {
+ this.parameters = parameters;
+ }
+
+ public String getJsonSchema() {
+ return this.jsonSchema;
+ }
+
+ public void setJsonSchema(String jsonSchema) {
+ this.jsonSchema = jsonSchema;
+ if (jsonSchema != null) {
+ this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema);
+ }
+ }
+
+ }
+
+ }
+
+ /**
+ * Represents a chat completion response returned by model, based on the provided
+ * input.
+ *
+ * @param id A unique identifier for the chat completion.
+ * @param finishReason The reason the model stopped generating tokens.
+ * @param message A chat completion message generated by streamed model responses.
+ * @param logprobs Log probability information for the choice.
+ * @param usage Usage statistics for the completion request.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record ChatCompletion(@JsonProperty("id") String id,
+ @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
+ @JsonProperty("message") ChatCompletionMessage.Provider message,
+ @JsonProperty("logprobs") LogProbs logprobs, @JsonProperty("usage") Usage usage) {
+ }
+
+ /**
+ * The reason the model stopped generating tokens.
+ */
+ public enum ChatCompletionFinishReason {
+
+ /**
+ * The model finished sending a complete message.
+ */
+ COMPLETE,
+
+ /**
+ * One of the provided stop_sequence entries was reached in the model’s
+ * generation.
+ */
+ STOP_SEQUENCE,
+
+ /**
+ * The number of generated tokens exceeded the model’s context length or the value
+ * specified via the max_tokens parameter.
+ */
+ MAX_TOKENS,
+
+ /**
+ * The model generated a Tool Call and is expecting a Tool Message in return
+ */
+ TOOL_CALL,
+
+ /**
+ * The model called a tool.
+ */
+ @JsonProperty("tool_calls")
+ TOOL_CALLS,
+
+ /**
+ * The generation failed due to an internal error
+ */
+ ERROR
+
+ }
+
+ /**
+ * Log probability information
+ *
+ * @param tokenIds The token ids of each token used to construct the text chunk.
+ * @param text The text chunk for which the log probabilities was calculated.
+ * @param logprobs The log probability of each token used to construct the text chunk.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record LogProbs(@JsonProperty("token_ids") List tokenIds, @JsonProperty("text") String text,
+ @JsonProperty("logprobs") List logprobs) {
+
+ }
+
+ /**
+ * Helper factory that creates a tool_choice of type 'REQUIRED', 'NONE' or selected
+ * function by name.
+ */
+ public static class ToolChoiceBuilder {
+
+ public static final String NONE = "NONE";
+
+ public static final String REQUIRED = "REQUIRED";
+
+ /**
+ * Specifying a particular function forces the model to call that function.
+ */
+ public static Object FUNCTION(String functionName) {
+ return Map.of("type", "function", "function", Map.of("name", functionName));
+ }
+
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public record ChatCompletionChunk(
+ // @formatter:off
+ @JsonProperty("id") String id,
+ @JsonProperty("type") String type,
+ @JsonProperty("index") Integer index,
+ @JsonProperty("delta") ChunkDelta delta) {
+ // @formatter:on
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public record ChunkDelta(
+ // @formatter:off
+ @JsonProperty("message") ChatCompletionMessage message,
+ @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
+ @JsonProperty("usage") Usage usage) {
+ // @formatter:on
+ }
+
+ }
+
+ /**
+ * List of well-known Cohere embedding models.
+ *
+ * @see Cohere Models Overview
+ */
+ public enum EmbeddingModel {
+
+ // @formatter:off
+
+ /**
+ * A model that allows for text and images to be classified or turned into embeddings
+ * dimensional - [256, 512, 1024, 1536 (default)]
+ */
+ EMBED_V4("embed-v4.0"),
+ /**
+ * Embed v3 Multilingual model for text embeddings.
+ * Produces 1024-dimensional embeddings suitable for multilingual semantic search,
+ * clustering, and other text similarity tasks.
+ */
+ EMBED_MULTILINGUAL_V3("embed-multilingual-v3.0"),
+
+ /**
+ * Embed v3 English model for text embeddings.
+ * Produces 1024-dimensional embeddings optimized for English text.
+ */
+ EMBED_ENGLISH_V3("embed-english-v3.0"),
+
+ /**
+ * Embed v3 Multilingual Light model.
+ * Smaller and faster variant with 1024 dimensions.
+ */
+ EMBED_MULTILINGUAL_LIGHT_V3("embed-multilingual-light-v3.0"),
+
+ /**
+ * Embed v3 English Light model.
+ * Smaller and faster English-only variant with 1024 dimensions.
+ */
+ EMBED_ENGLISH_LIGHT_V3("embed-english-light-v3.0");
+ // @formatter:on
+
+ private final String value;
+
+ EmbeddingModel(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return this.value;
+ }
+
+ }
+
+ /**
+ * Embedding type
+ */
+ public enum EmbeddingType {
+
+ /**
+ * Use this when you want to get back the default float embeddings. Supported with
+ * all Embed models.
+ */
+ @JsonProperty("float")
+ FLOAT,
+
+ /**
+ * Use this when you want to get back signed int8 embeddings. Supported with Embed
+ * v3.0 and newer Embed models.
+ */
+ @JsonProperty("int8")
+ INT8,
+
+ /**
+ * Use this when you want to get back unsigned int8 embeddings. Supported with
+ * Embed v3.0 and newer Embed models.
+ */
+ @JsonProperty("uint8")
+ UINT8,
+
+ /**
+ * Use this when you want to get back signed binary embeddings. Supported with
+ * Embed v3.0 and newer Embed models.
+ */
+ @JsonProperty("binary")
+ BINARY,
+
+ /**
+ * Use this when you want to get back unsigned binary embeddings. Supported with
+ * Embed v3.0 and newer Embed models.
+ */
+ @JsonProperty("ubinary")
+ UBINARY,
+
+ /**
+ * Use this when you want to get back base64 embeddings. Supported with Embed v3.0
+ * and newer Embed models.
+ */
+ @JsonProperty("base64")
+ BASE64
+
+ }
+
+ /**
+ * Embedding request.
+ *
+ * @param texts An array of strings to embed.
+ * @param images An array of images to embed as data URIs.
+ * @param model The model to use for embedding.
+ * @param inputType The type of input (search_document, search_query, classification,
+ * clustering, image).
+ * @param embeddingTypes The types of embeddings to return (float, int8, uint8,
+ * binary, ubinary).
+ * @param truncate How to handle inputs longer than the maximum token length (NONE,
+ * START, END).
+ * @param Type of the input (String or List of tokens).
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record EmbeddingRequest(
+ // @formatter:off
+ @JsonProperty("texts") List texts,
+ @JsonProperty("images") List images,
+ @JsonProperty("model") String model,
+ @JsonProperty("input_type") InputType inputType,
+ @JsonProperty("embedding_types") List embeddingTypes,
+ @JsonProperty("truncate") Truncate truncate) {
+ // @formatter:on
+
+ public static Builder builder() {
+ return new Builder<>();
+ }
+
+ public static final class Builder {
+
+ private String model = EmbeddingModel.EMBED_V4.getValue();
+
+ private List texts;
+
+ private List images;
+
+ private InputType inputType = InputType.SEARCH_DOCUMENT;
+
+ private List embeddingTypes = List.of(EmbeddingType.FLOAT);
+
+ private Truncate truncate = Truncate.END;
+
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public Builder texts(Object raw) {
+ if (raw == null) {
+ this.texts = null;
+ return this;
+ }
+
+ if (raw instanceof List> list) {
+ this.texts = (List) list;
+ }
+ else {
+ this.texts = List.of((T) raw);
+ }
+ return this;
+ }
+
+ public Builder images(List images) {
+ this.images = images;
+ return this;
+ }
+
+ public Builder inputType(InputType inputType) {
+ this.inputType = inputType;
+ return this;
+ }
+
+ public Builder embeddingTypes(List embeddingTypes) {
+ this.embeddingTypes = embeddingTypes;
+ return this;
+ }
+
+ public Builder truncate(Truncate truncate) {
+ this.truncate = truncate;
+ return this;
+ }
+
+ public EmbeddingRequest build() {
+ return new EmbeddingRequest<>(this.texts, this.images, this.model, this.inputType, this.embeddingTypes,
+ this.truncate);
+ }
+
+ }
+
+ /**
+ * Input type for embeddings.
+ */
+ public enum InputType {
+
+ // @formatter:off
+ @JsonProperty("search_document")
+ SEARCH_DOCUMENT,
+ @JsonProperty("search_query")
+ SEARCH_QUERY,
+ @JsonProperty("classification")
+ CLASSIFICATION,
+ @JsonProperty("clustering")
+ CLUSTERING,
+ @JsonProperty("image")
+ IMAGE
+ // @formatter:on
+
+ }
+
+ /**
+ * Truncation strategy for inputs longer than maximum token length.
+ */
+ public enum Truncate {
+
+ // @formatter:off
+ @JsonProperty("NONE")
+ NONE,
+ @JsonProperty("START")
+ START,
+ @JsonProperty("END")
+ END
+ // @formatter:on
+
+ }
+
+ }
+
+ /**
+ * Embedding response.
+ *
+ * @param id Unique identifier for the embedding request.
+ * @param embeddings The embeddings
+ * @param texts The texts that were embedded.
+ * @param responseType The type of response ("embeddings_floats" or
+ * "embeddings_by_type").
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public record EmbeddingResponse(
+ // @formatter:off
+ @JsonProperty("id") String id,
+ @JsonProperty("embeddings") Object embeddings,
+ @JsonProperty("texts") List texts,
+ @JsonProperty("response_type") String responseType) {
+ // @formatter:on
+
+ /**
+ * Extracts float embeddings from the response. Handles both response formats: -
+ * "embeddings_floats": embeddings is List<List<Double>> -
+ * "embeddings_by_type": embeddings is Map with "float" key containing
+ * List<List<Double>>
+ * @return List of float arrays representing the embeddings
+ */
+ @JsonIgnore
+ @SuppressWarnings("unchecked")
+ public List getFloatEmbeddings() {
+ if (this.embeddings == null) {
+ return List.of();
+ }
+
+ // Handle "embeddings_floats" format: embeddings is directly
+ // List>
+ if (this.embeddings instanceof List> embeddingsList) {
+ return embeddingsList.stream().map(embedding -> {
+ if (embedding instanceof List> embeddingVector) {
+ float[] floatArray = new float[embeddingVector.size()];
+ for (int i = 0; i < embeddingVector.size(); i++) {
+ Object value = embeddingVector.get(i);
+ floatArray[i] = (value instanceof Number number) ? number.floatValue() : 0f;
+ }
+ return floatArray;
+ }
+ return new float[0];
+ }).toList();
+ }
+
+ // Handle "embeddings_by_type" format: embeddings is Map
+ if (this.embeddings instanceof Map, ?> embeddingsMap) {
+ Object floatEmbeddings = embeddingsMap.get("float");
+ if (floatEmbeddings instanceof List> embeddingsList) {
+ return embeddingsList.stream().map(embedding -> {
+ if (embedding instanceof List> embeddingVector) {
+ float[] floatArray = new float[embeddingVector.size()];
+ for (int i = 0; i < embeddingVector.size(); i++) {
+ Object value = embeddingVector.get(i);
+ floatArray[i] = (value instanceof Number number) ? number.floatValue() : 0f;
+ }
+ return floatArray;
+ }
+ return new float[0];
+ }).toList();
+ }
+ }
+
+ return List.of();
+ }
+
+ }
+
+ public enum EventType {
+
+ MESSAGE_END("message-end"), CONTENT_START("content-start"), CONTENT_DELTA("content-delta"),
+ CONTENT_END("content-end"), TOOL_PLAN_DELTA("tool-plan-delta"), TOOL_CALL_START("tool-call-start"),
+ TOOL_CALL_DELTA("tool-call-delta"), CITATION_START("citation-start");
+
+ public final String value;
+
+ EventType(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return this.value;
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java
new file mode 100644
index 00000000000..63a961ac972
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/api/CohereStreamFunctionCallingHelper.java
@@ -0,0 +1,296 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionChunk;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ChatCompletionFunction;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall;
+import org.springframework.ai.cohere.api.CohereApi.EventType;
+import org.springframework.util.ObjectUtils;
+
+/**
+ * Helper class for handling streaming function calling in Cohere API.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereStreamFunctionCallingHelper {
+
+ /**
+ * Merge the previous and current ChatCompletionChunk into a single one.
+ * @param previous the previous ChatCompletionChunk
+ * @param current the current ChatCompletionChunk
+ * @return the merged ChatCompletionChunk
+ */
+ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {
+
+ if (previous == null) {
+ return current;
+ }
+
+ if (current == null) {
+ return previous;
+ }
+
+ var previousDelta = previous.delta();
+ var currentDelta = current.delta();
+
+ ChatCompletionMessage previousMessage = previousDelta != null ? previousDelta.message() : null;
+ ChatCompletionMessage currentMessage = currentDelta != null ? currentDelta.message() : null;
+
+ Role role = previousMessage != null && previousMessage.role() != null ? previousMessage.role()
+ : (currentMessage != null ? currentMessage.role() : null);
+
+ String previousText = previousMessage != null ? extractTextFromRawContent(previousMessage.rawContent()) : "";
+
+ String currentText = currentMessage != null ? extractTextFromRawContent(currentMessage.rawContent()) : "";
+
+ String currentType = current.type();
+ String mergedText;
+ if (EventType.CONTENT_START.getValue().equals(currentType)) {
+ mergedText = currentText;
+ }
+ else if (EventType.CONTENT_END.getValue().equals(currentType)) {
+ mergedText = previousText;
+ }
+ else {
+ mergedText = previousText + currentText;
+ }
+
+ String previousPlan = previousMessage != null ? previousMessage.toolPlan() : null;
+ String currentPlan = currentMessage != null ? currentMessage.toolPlan() : null;
+
+ String mergedToolPlan = previousPlan;
+
+ if (EventType.TOOL_PLAN_DELTA.getValue().equals(current.type())) {
+ mergedToolPlan = mergeToolPlan(previousPlan, currentPlan);
+ }
+
+ List mergedToolCalls = mergeToolCalls(previous, current);
+
+ List citations = mergeCitations(previous, current);
+
+ ChatCompletionMessage mergedMessage = new ChatCompletionMessage(mergedText, role, mergedToolPlan,
+ mergedToolCalls, citations, null);
+
+ var finishReason = (currentDelta != null && currentDelta.finishReason() != null) ? currentDelta.finishReason()
+ : (previousDelta != null ? previousDelta.finishReason() : null);
+
+ var usage = (currentDelta != null && currentDelta.usage() != null) ? currentDelta.usage()
+ : (previousDelta != null ? previousDelta.usage() : null);
+
+ var mergedDelta = new ChatCompletionChunk.ChunkDelta(mergedMessage, finishReason, usage);
+
+ String id = current.id() != null ? current.id() : previous.id();
+ String type = current.type() != null ? current.type() : previous.type();
+ Integer index = current.index() != null ? current.index() : previous.index();
+
+ return new ChatCompletionChunk(id, type, index, mergedDelta);
+ }
+
+ public ChatCompletionChunk sanitizeToolCalls(ChatCompletionChunk chunk) {
+ if (chunk == null || chunk.delta() == null || chunk.delta().message() == null) {
+ return chunk;
+ }
+
+ ChatCompletionMessage msg = chunk.delta().message();
+ List toolCalls = msg.toolCalls();
+
+ if (toolCalls == null || toolCalls.isEmpty()) {
+ return chunk;
+ }
+
+ List cleaned = toolCalls.stream().filter(this::isValidToolCall).toList();
+
+ ChatCompletionChunk.ChunkDelta oldDelta = chunk.delta();
+
+ ChatCompletionMessage cleanedMsg = new ChatCompletionMessage(msg.rawContent(), msg.role(), msg.toolPlan(),
+ cleaned.isEmpty() ? null : cleaned, msg.citations(), null);
+
+ ChatCompletionChunk.ChunkDelta newDelta = new ChatCompletionChunk.ChunkDelta(cleanedMsg,
+ oldDelta.finishReason(), oldDelta.usage());
+
+ return new ChatCompletionChunk(chunk.id(), chunk.type(), chunk.index(), newDelta);
+ }
+
+ public boolean hasValidToolCallsOnly(ChatCompletionChunk c) {
+ if (c == null || c.delta() == null || c.delta().message() == null) {
+ return false;
+ }
+
+ ChatCompletionMessage message = c.delta().message();
+ List calls = message.toolCalls();
+
+ boolean hasValidToolCalls = calls != null && calls.stream().anyMatch(this::isValidToolCall);
+
+ boolean hasTextContent = message.rawContent() != null
+ && !extractTextFromRawContent(message.rawContent()).isEmpty();
+
+ boolean hasCitations = message.citations() != null && !message.citations().isEmpty();
+
+ return hasValidToolCalls || hasTextContent || hasCitations;
+ }
+
+ private boolean isValidToolCall(ToolCall toolCall) {
+ if (toolCall == null || toolCall.function() == null) {
+ return false;
+ }
+ ChatCompletionFunction chatCompletionFunction = toolCall.function();
+ String functionName = chatCompletionFunction.name();
+ String functionArguments = chatCompletionFunction.arguments();
+ return !ObjectUtils.isEmpty(functionName) && !ObjectUtils.isEmpty(functionArguments);
+ }
+
+ private String extractTextFromRawContent(Object rawContent) {
+ if (rawContent == null) {
+ return "";
+ }
+ if (rawContent instanceof Map, ?> map) {
+ Object text = map.get("text");
+ if (text != null) {
+ return text.toString();
+ }
+ }
+ if (rawContent instanceof List> list) {
+ StringBuilder sb = new StringBuilder();
+ for (Object item : list) {
+ if (item instanceof Map, ?> m) {
+ Object text = m.get("text");
+ if (text != null) {
+ sb.append(text);
+ }
+ }
+ else if (item instanceof String s) {
+ sb.append(s);
+ }
+ }
+ return sb.toString();
+ }
+ if (rawContent instanceof String s) {
+ return s;
+ }
+ return rawContent.toString();
+ }
+
+ private List mergeToolCalls(ChatCompletionChunk previous, ChatCompletionChunk current) {
+ ChatCompletionMessage previousMessage = previous != null && previous.delta() != null
+ ? previous.delta().message() : null;
+ ChatCompletionMessage currentMessage = current.delta() != null ? current.delta().message() : null;
+
+ List merged = ensureToolCallList(previousMessage != null ? previousMessage.toolCalls() : null);
+
+ String type = current.type();
+ Integer index = current.index();
+
+ if (index == null) {
+ return merged;
+ }
+
+ ToolCall existing = ensureToolCallAtIndex(merged, index);
+ ChatCompletionFunction existingFunction = existing.function() != null ? existing.function()
+ : new ChatCompletionFunction(null, "");
+
+ String id = existing.id();
+ String callType = existing.type();
+ String functionName = existingFunction.name();
+ String args = existingFunction.arguments() != null ? existingFunction.arguments() : "";
+
+ if (EventType.TOOL_CALL_START.getValue().equals(type) && currentMessage != null
+ && currentMessage.toolCalls() != null && !currentMessage.toolCalls().isEmpty()) {
+
+ ToolCall start = currentMessage.toolCalls().get(0);
+ ChatCompletionFunction startFunction = start.function() != null ? start.function()
+ : new ChatCompletionFunction(null, "");
+
+ id = start.id() != null ? start.id() : id;
+ callType = start.type() != null ? start.type() : callType;
+ functionName = startFunction.name() != null ? startFunction.name() : functionName;
+
+ }
+
+ if (EventType.TOOL_CALL_DELTA.getValue().equals(type) && currentMessage != null
+ && currentMessage.toolCalls() != null && !currentMessage.toolCalls().isEmpty()) {
+
+ ToolCall deltaCall = currentMessage.toolCalls().get(0);
+ ChatCompletionFunction deltaFunction = deltaCall.function();
+ if (deltaFunction != null && deltaFunction.arguments() != null) {
+ args = (args == null ? "" : args) + deltaFunction.arguments();
+ }
+ }
+
+ // tool-call-end
+ ChatCompletionFunction mergedFn = new ChatCompletionFunction(functionName, args);
+ ToolCall mergedCall = new ToolCall(id, callType, mergedFn, index);
+ merged.set(index, mergedCall);
+
+ return merged;
+ }
+
+ private String mergeToolPlan(final String previous, final String currentFragment) {
+ if (currentFragment == null || currentFragment.isEmpty()) {
+ return previous;
+ }
+ if (previous == null) {
+ return currentFragment;
+ }
+ return previous + currentFragment;
+ }
+
+ private List mergeCitations(final ChatCompletionChunk previous,
+ final ChatCompletionChunk current) {
+
+ ChatCompletionMessage previousMessage = previous != null && previous.delta() != null
+ ? previous.delta().message() : null;
+ ChatCompletionMessage currentMessage = current != null && current.delta() != null ? current.delta().message()
+ : null;
+
+ List merged = new ArrayList<>();
+
+ if (previousMessage != null && previousMessage.citations() != null) {
+ merged.addAll(previousMessage.citations());
+ }
+
+ if (current != null && EventType.CITATION_START.getValue().equals(current.type()) && currentMessage != null
+ && currentMessage.citations() != null) {
+ merged.addAll(currentMessage.citations());
+ }
+
+ return merged.isEmpty() ? null : merged;
+ }
+
+ private List ensureToolCallList(final List toolCalls) {
+ return (toolCalls != null) ? new ArrayList<>(toolCalls) : new ArrayList<>();
+ }
+
+ private ToolCall ensureToolCallAtIndex(final List toolCalls, final int index) {
+ while (toolCalls.size() <= index) {
+ toolCalls.add(new ToolCall(null, null, new ChatCompletionFunction("", ""), index));
+ }
+ ToolCall call = toolCalls.get(index);
+ if (call == null) {
+ call = new ToolCall(null, null, new ChatCompletionFunction("", ""), index);
+ toolCalls.set(index, call);
+ }
+ return call;
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java
new file mode 100644
index 00000000000..0e8e8382333
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatModel.java
@@ -0,0 +1,689 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.chat;
+
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import io.micrometer.observation.Observation;
+import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
+
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.ToolResponseMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
+import org.springframework.ai.chat.observation.ChatModelObservationContext;
+import org.springframework.ai.chat.observation.ChatModelObservationConvention;
+import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletion;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest;
+import org.springframework.ai.cohere.api.CohereApi.FunctionTool;
+import org.springframework.ai.content.Media;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
+import org.springframework.ai.model.tool.ToolCallingChatOptions;
+import org.springframework.ai.model.tool.ToolCallingManager;
+import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
+import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.support.UsageCalculator;
+import org.springframework.ai.tool.definition.ToolDefinition;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.http.ResponseEntity;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.MimeType;
+
+/**
+ * Represents a Cohere Chat Model.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereChatModel implements ChatModel {
+
+ private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
+
+ private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();
+
+ private final Logger logger = LoggerFactory.getLogger(getClass());
+
+ /**
+ * The default options used for the chat completion requests.
+ */
+ private final CohereChatOptions defaultOptions;
+
+ /**
+ * Low-level access to the Cohere API.
+ */
+ private final CohereApi cohereApi;
+
+ private final RetryTemplate retryTemplate;
+
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ private final ToolCallingManager toolCallingManager;
+
+ /**
+ * The tool execution eligibility predicate used to determine if a tool can be
+ * executed.
+ */
+ private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
+
+ public CohereChatModel(CohereApi cohereApi, CohereChatOptions defaultOptions, ToolCallingManager toolCallingManager,
+ RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
+ this(cohereApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
+ new DefaultToolExecutionEligibilityPredicate());
+ }
+
+ public CohereChatModel(CohereApi cohereApi, CohereChatOptions defaultOptions, ToolCallingManager toolCallingManager,
+ RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
+ ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
+ Assert.notNull(cohereApi, "cohereApi cannot be null");
+ Assert.notNull(defaultOptions, "defaultOptions cannot be null");
+ Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
+ Assert.notNull(retryTemplate, "retryTemplate cannot be null");
+ Assert.notNull(observationRegistry, "observationRegistry cannot be null");
+ Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
+ this.cohereApi = cohereApi;
+ this.defaultOptions = defaultOptions;
+ this.toolCallingManager = toolCallingManager;
+ this.retryTemplate = retryTemplate;
+ this.observationRegistry = observationRegistry;
+ this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
+ }
+
+ public static ChatResponseMetadata from(ChatCompletion result) {
+ Assert.notNull(result, "Cohere ChatCompletion must not be null");
+ DefaultUsage usage = getDefaultUsage(result.usage());
+ return ChatResponseMetadata.builder().id(result.id()).usage(usage).build();
+ }
+
+ public static ChatResponseMetadata from(ChatCompletion result, Usage usage) {
+ Assert.notNull(result, "Cohere ChatCompletion must not be null");
+ return ChatResponseMetadata.builder().id(result.id()).usage(usage).build();
+ }
+
+ private static DefaultUsage getDefaultUsage(CohereApi.Usage usage) {
+ return new DefaultUsage(usage.tokens().inputTokens(), usage.tokens().outputTokens(), null, usage);
+ }
+
+ @Override
+ public ChatResponse call(Prompt prompt) {
+ Prompt requestPrompt = buildRequestPrompt(prompt);
+ return this.internalCall(requestPrompt, null);
+ }
+
+ @Override
+ public Flux stream(Prompt prompt) {
+ Prompt requestPrompt = buildRequestPrompt(prompt);
+ return this.internalStream(requestPrompt, null);
+ }
+
+ public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) {
+ return Flux.deferContextual(contextView -> {
+ var request = createRequest(prompt, true);
+
+ ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+ .prompt(prompt)
+ .provider(CohereApi.PROVIDER_NAME)
+ .build();
+
+ Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+ this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry);
+
+ observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+
+ Flux completionChunks = RetryUtils.execute(this.retryTemplate,
+ () -> this.cohereApi.chatCompletionStream(request));
+
+ // For chunked responses, only the first chunk contains the role.
+ // The rest of the chunks with same ID share the same role.
+ ConcurrentHashMap roleMap = new ConcurrentHashMap<>();
+
+ // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
+ // the function call handling logic.
+ Flux chatResponse = completionChunks.map(this::toChatCompletion)
+ .filter(chatCompletion -> chatCompletion != null && chatCompletion.message() != null)
+ .switchMap(chatCompletion -> Mono.just(chatCompletion).map(completion -> {
+ try {
+ @SuppressWarnings("null")
+ String id = completion.id();
+ ChatCompletionMessage.Provider message = completion.message();
+
+ // Store the role for this completion ID
+ if (message.role() != null && id != null) {
+ roleMap.putIfAbsent(id, message.role().name());
+ }
+
+ List generations = message.content().stream().map(content -> {
+ Map metadata = Map.of("id", completion.id() != null ? completion.id() : "",
+ "role", completion.id() != null ? roleMap.getOrDefault(id, "") : "", "finishReason",
+ completion.finishReason() != null ? completion.finishReason().name() : "");
+ return buildGeneration(content, completion, metadata);
+ }).toList();
+
+ if (completion.usage() != null) {
+ DefaultUsage usage = getDefaultUsage(completion.usage());
+ Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse);
+ return new ChatResponse(generations, from(completion, cumulativeUsage));
+ }
+ else {
+ return new ChatResponse(generations);
+ }
+ }
+ catch (Exception e) {
+ logger.error("Error processing chat completion", e);
+ return new ChatResponse(List.of());
+ }
+ }));
+
+ // @formatter:off
+ Flux chatResponseFlux = chatResponse.flatMap(response -> {
+ if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
+ return Flux.deferContextual(ctx -> {
+ ToolExecutionResult toolExecutionResult;
+ try {
+ ToolCallReactiveContextHolder.setContext(ctx);
+ toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+ }
+ finally {
+ ToolCallReactiveContextHolder.clearContext();
+ }
+ if (toolExecutionResult.returnDirect()) {
+ // Return tool execution result directly to the client.
+ return Flux.just(ChatResponse.builder().from(response)
+ .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
+ .build());
+ }
+ else {
+ // Send the tool execution result back to the model.
+ var chatOptions = CohereChatOptions.fromOptions(prompt.getOptions().copy());
+ return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), chatOptions),
+ response);
+ }
+ }).subscribeOn(Schedulers.boundedElastic());
+ }
+ else {
+ return Flux.just(response);
+ }
+ })
+ .doOnError(observation::error)
+ .doFinally(s -> observation.stop())
+ .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+ // @formatter:on;
+
+ return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
+ });
+
+ }
+
+ private ChatCompletion toChatCompletion(CohereApi.ChatCompletionChunk chunk) {
+ if (chunk == null || chunk.delta() == null) {
+ return null;
+ }
+
+ CohereApi.ChatCompletionChunk.ChunkDelta delta = chunk.delta();
+ ChatCompletionMessage message = delta.message();
+
+ ChatCompletionMessage.Provider provider = null;
+ if (message != null) {
+
+ List content = extractMessageContent(message.rawContent());
+
+ provider = new ChatCompletionMessage.Provider(content, message.role(), message.toolPlan(),
+ message.toolCalls(), message.citations());
+ }
+
+ return new ChatCompletion(chunk.id(), delta.finishReason(), provider, null, delta.usage());
+ }
+
+ private List extractMessageContent(Object rawContent) {
+ if (rawContent == null) {
+ return List.of();
+ }
+
+ if (rawContent instanceof String text) {
+ return List.of(new ChatCompletionMessage.MessageContent("text", text, null));
+ }
+
+ if (rawContent instanceof List> list) {
+ List messageContents = new ArrayList<>();
+ for (Object item : list) {
+ if (item instanceof ChatCompletionMessage.MessageContent mc) {
+ messageContents.add(mc);
+ }
+ else if (item instanceof Map, ?> map) {
+ String type = (String) map.get("type");
+ String text = (String) map.get("text");
+ Object value = map.get("value");
+ messageContents.add(new ChatCompletionMessage.MessageContent(type, text, value));
+ }
+ }
+ return messageContents;
+ }
+
+ if (rawContent instanceof Map, ?> map) {
+ String type = (String) map.get("type");
+ String text = (String) map.get("text");
+ Object value = map.get("value");
+ return List.of(new ChatCompletionMessage.MessageContent(type != null ? type : "text", text, value));
+ }
+
+ return List.of();
+ }
+
+ Prompt buildRequestPrompt(Prompt prompt) {
+ // Process runtime options
+ CohereChatOptions runtimeOptions = null;
+ if (prompt.getOptions() != null) {
+ if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
+ runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
+ CohereChatOptions.class);
+ }
+ else {
+ runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
+ CohereChatOptions.class);
+ }
+ }
+
+ // Define request options by merging runtime options and default options
+ CohereChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
+ CohereChatOptions.class);
+
+ // Merge @JsonIgnore-annotated options explicitly since they are ignored by
+ // Jackson, used by ModelOptionsUtils.
+ if (runtimeOptions != null) {
+ requestOptions.setInternalToolExecutionEnabled(
+ ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
+ this.defaultOptions.getInternalToolExecutionEnabled()));
+ requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
+ this.defaultOptions.getToolNames()));
+ requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
+ this.defaultOptions.getToolCallbacks()));
+ requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
+ this.defaultOptions.getToolContext()));
+ }
+ else {
+ requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
+ requestOptions.setToolNames(this.defaultOptions.getToolNames());
+ requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
+ requestOptions.setToolContext(this.defaultOptions.getToolContext());
+ }
+
+ ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
+
+ return new Prompt(prompt.getInstructions(), requestOptions);
+ }
+
+ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
+ ChatCompletionRequest request = createRequest(prompt, false);
+
+ ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+ .prompt(prompt)
+ .provider(CohereApi.PROVIDER_NAME)
+ .build();
+
+ ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+
+ ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate,
+ () -> this.cohereApi.chatCompletionEntity(request));
+
+ ChatCompletion chatCompletion = completionEntity.getBody();
+
+ if (chatCompletion == null) {
+ logger.warn("No chat completion returned for prompt: {}", prompt);
+ return new ChatResponse(List.of());
+ }
+
+ final Map metadata = Map.of("id",
+ chatCompletion.id() != null ? chatCompletion.id() : "", "role",
+ chatCompletion.message().role() != null ? chatCompletion.message().role().name() : "",
+ "finishReason",
+ chatCompletion.finishReason() != null ? chatCompletion.finishReason().name() : "");
+
+ List generations = new ArrayList<>();
+
+ if (chatCompletion.finishReason() == null) { // Just for secure
+ logger.warn("No chat completion finishReason returned for prompt: {}", prompt);
+ return new ChatResponse(List.of());
+ }
+
+ if (chatCompletion.finishReason().equals(CohereApi.ChatCompletionFinishReason.TOOL_CALL)) {
+ var generation = buildGeneration(null, chatCompletion, metadata);
+ generations.add(generation);
+ }
+ else {
+ generations = chatCompletion.message()
+ .content()
+ .stream()
+ .map(content -> buildGeneration(content, chatCompletion, metadata))
+ .toList();
+ }
+
+ DefaultUsage usage = getDefaultUsage(completionEntity.getBody().usage());
+ Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse);
+
+ ChatResponse chatResponse = new ChatResponse(generations,
+ from(completionEntity.getBody(), cumulativeUsage));
+
+ observationContext.setResponse(chatResponse);
+
+ return chatResponse;
+ });
+
+ if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
+ var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+ if (toolExecutionResult.returnDirect()) {
+ // Return tool execution result directly to the client.
+ return ChatResponse.builder()
+ .from(response)
+ .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
+ .build();
+ }
+ else {
+ // remove tools actions before
+ ChatOptions chatOptions = CohereChatOptions.fromOptions2(prompt.getOptions().copy());
+ // Send the tool execution result back to the model.
+ return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), chatOptions), null);
+ }
+ }
+
+ return response;
+ }
+
+ private Generation buildGeneration(ChatCompletionMessage.MessageContent content, ChatCompletion completion,
+ Map metadata) {
+ List toolCalls = completion.message().toolCalls() == null ? List.of()
+ : completion.message()
+ .toolCalls()
+ .stream()
+ .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
+ toolCall.function().name(), toolCall.function().arguments()))
+ .toList();
+
+ var assistantMessage = AssistantMessage.builder()
+ .content(content != null ? content.text() : "")
+ .toolCalls(toolCalls)
+ .properties(metadata)
+ .build();
+
+ String finishReason = (completion.finishReason() != null ? completion.finishReason().name() : "");
+ var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
+ return new Generation(assistantMessage, generationMetadata);
+ }
+
+ /**
+ * Accessible for testing.
+ */
+ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+ List chatCompletionMessages = prompt.getInstructions()
+ .stream()
+ .map(this::convertToCohereMessage)
+ .flatMap(List::stream)
+ .toList();
+
+ var request = new ChatCompletionRequest(chatCompletionMessages, stream);
+
+ CohereChatOptions requestOptions = (CohereChatOptions) prompt.getOptions();
+ request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);
+
+ // Add the tool definitions to the request's tools parameter.
+ List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
+ if (!CollectionUtils.isEmpty(toolDefinitions)) {
+ request = ModelOptionsUtils.merge(
+ CohereChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
+ ChatCompletionRequest.class);
+ }
+
+ return request;
+ }
+
+ /**
+ * Convert a Spring AI message to Cohere API message(s).
+ * @param message the Spring AI message to convert
+ * @return list of Cohere ChatCompletionMessage(s)
+ */
+ private List convertToCohereMessage(Message message) {
+ return switch (message.getMessageType()) {
+ case USER -> convertUserMessage(message);
+ case ASSISTANT -> convertAssistantMessage(message);
+ case SYSTEM -> convertSystemMessage(message);
+ case TOOL -> convertToolMessage(message);
+ };
+ }
+
+ /**
+ * Convert a USER message.
+ */
+ private List convertUserMessage(org.springframework.ai.chat.messages.Message message) {
+ Object content = message.getText();
+
+ if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
+ // Validate images before processing
+ CohereImageValidator.validateImages(userMessage.getMedia());
+
+ List contentList = new ArrayList<>(
+ List.of(new ChatCompletionMessage.MediaContent(message.getText())));
+
+ contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
+
+ content = contentList;
+ }
+
+ return List.of(new ChatCompletionMessage(content, Role.USER));
+ }
+
+ /**
+ * Convert an ASSISTANT message.
+ */
+ private List convertAssistantMessage(org.springframework.ai.chat.messages.Message message) {
+ if (!(message instanceof AssistantMessage assistantMessage)) {
+ throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName());
+ }
+
+ List toolCalls = null;
+ if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
+ toolCalls = convertToolCalls(assistantMessage.getToolCalls());
+ }
+
+ return List.of(new ChatCompletionMessage(assistantMessage.getText(), Role.ASSISTANT, toolCalls));
+ }
+
+ /**
+ * Convert tool calls.
+ */
+ private List convertToolCalls(List springToolCalls) {
+ return springToolCalls.stream().map(toolCall -> {
+ var function = new ChatCompletionMessage.ChatCompletionFunction(toolCall.name(), toolCall.arguments());
+ return new ToolCall(toolCall.id(), toolCall.type(), function, null);
+ }).toList();
+ }
+
+ /**
+ * Convert a SYSTEM message.
+ */
+ private List convertSystemMessage(org.springframework.ai.chat.messages.Message message) {
+ return List.of(new ChatCompletionMessage(message.getText(), Role.SYSTEM));
+ }
+
+ /**
+ * Convert a TOOL response message to Cohere format. Validates that all tool responses
+ * have an ID.
+ */
+ private List convertToolMessage(Message message) {
+
+ if (!(message instanceof ToolResponseMessage toolResponseMessage)) {
+ throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName());
+ }
+
+ return toolResponseMessage.getResponses().stream().map(toolResponse -> {
+ Assert.notNull(toolResponse.id(), "ToolResponseMessage.ToolResponse must have an id");
+ return new ChatCompletionMessage(toolResponse.responseData(), Role.TOOL, toolResponse.name(), null, null,
+ toolResponse.id());
+ }).toList();
+ }
+
+ private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
+ CohereApi.ChatCompletionMessage.MediaContent.DetailLevel detail = this.defaultOptions != null
+ ? this.defaultOptions.getImageDetail() : null;
+ return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
+ this.fromMediaData(media.getMimeType(), media.getData()), detail));
+ }
+
+ private String fromMediaData(MimeType mimeType, Object mediaContentData) {
+ if (mediaContentData instanceof byte[] bytes) {
+ // Assume the bytes are an image. So, convert the bytes to a base64 encoded
+ // following the prefix pattern.
+ return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
+ }
+ else if (mediaContentData instanceof String text) {
+ // Assume the text is a URLs or a base64 encoded image prefixed by the user.
+ return text;
+ }
+ else {
+ throw new IllegalArgumentException(
+ "Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
+ }
+ }
+
+ private List getFunctionTools(List toolDefinitions) {
+ return toolDefinitions.stream().map(toolDefinition -> {
+ var function = new FunctionTool.Function(toolDefinition.description(), toolDefinition.name(),
+ toolDefinition.inputSchema());
+ return new FunctionTool(function);
+ }).toList();
+ }
+
+ @Override
+ public ChatOptions getDefaultOptions() {
+ return CohereChatOptions.fromOptions(this.defaultOptions);
+ }
+
+ /**
+ * Use the provided convention for reporting observation data
+ * @param observationConvention The provided convention
+ */
+ public void setObservationConvention(ChatModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "observationConvention cannot be null");
+ this.observationConvention = observationConvention;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static final class Builder {
+
+ private CohereApi cohereApi;
+
+ private CohereChatOptions defaultOptions = CohereChatOptions.builder()
+ .temperature(0.3)
+ .topP(1.0)
+ .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue())
+ .build();
+
+ private ToolCallingManager toolCallingManager;
+
+ private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();
+
+ private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ private Builder() {
+ }
+
+ public Builder cohereApi(CohereApi cohereApi) {
+ this.cohereApi = cohereApi;
+ return this;
+ }
+
+ public Builder defaultOptions(CohereChatOptions defaultOptions) {
+ this.defaultOptions = defaultOptions;
+ return this;
+ }
+
+ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
+ this.toolCallingManager = toolCallingManager;
+ return this;
+ }
+
+ public Builder toolExecutionEligibilityPredicate(
+ ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
+ this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
+ return this;
+ }
+
+ public Builder retryTemplate(RetryTemplate retryTemplate) {
+ this.retryTemplate = retryTemplate;
+ return this;
+ }
+
+ public Builder observationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public CohereChatModel build() {
+ if (this.toolCallingManager != null) {
+ return new CohereChatModel(this.cohereApi, this.defaultOptions, this.toolCallingManager,
+ this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate);
+ }
+ return new CohereChatModel(this.cohereApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
+ this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate);
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java
new file mode 100644
index 00000000000..6f9955180cf
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereChatOptions.java
@@ -0,0 +1,624 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.chat;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.MediaContent.DetailLevel;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ResponseFormat;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice;
+import org.springframework.ai.cohere.api.CohereApi.FunctionTool;
+import org.springframework.ai.model.tool.ToolCallingChatOptions;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
+/**
+ * Options for the Cohere API.
+ *
+ * @author Ricken Bazolo
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class CohereChatOptions implements ToolCallingChatOptions {
+
+ /**
+ * ID of the model to use
+ */
+ private @JsonProperty("model") String model;
+
+ /**
+ * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will
+ * make the output more random, while lower values like 0.2 will make it more focused
+ * and deterministic. We generally recommend altering this or top_p but not both.
+ */
+ private @JsonProperty("temperature") Double temperature;
+
+ /**
+ * Ensures that only the most likely tokens, with total probability mass of p, are
+ * considered for generation at each step. If both k and p are enabled, p acts after
+ * k. Defaults to 0.75. min value of 0.01, max value of 0.99.
+ */
+ private @JsonProperty("p") Double p;
+
+ /**
+ * The maximum number of tokens to generate in the chat completion. The total length
+ * of input tokens and generated tokens is limited by the model's context length.
+ */
+ private @JsonProperty("max_tokens") Integer maxTokens;
+
+ /**
+ * Min value of 0.0, max value of 1.0. Used to reduce repetitiveness of generated
+ * tokens. Similar to frequency_penalty, except that this penalty is applied equally
+ * to all tokens that have already appeared, regardless of their exact frequencies.
+ */
+ private @JsonProperty("presence_penalty") Double presencePenalty;
+
+ /**
+ * Min value of 0.0, max value of 1.0. Used to reduce repetitiveness of generated
+ * tokens. Similar to frequency_penalty, except that this penalty is applied equally
+ * to all tokens that have already appeared, regardless of their exact frequencies.
+ */
+ private @JsonProperty("frequency_penalty") Double frequencyPenalty;
+
+ /**
+ * Ensures that only the top k most likely tokens are considered for generation at
+ * each step. When k is set to 0, k-sampling is disabled. Defaults to 0, min value of
+ * 0, max value of 500.
+ */
+ private @JsonProperty("k") Integer k;
+
+ /**
+ * A list of tools the model may call. Currently, only functions are supported as a
+ * tool. Use this to provide a list of functions the model may generate JSON inputs
+ * for.
+ */
+ private @JsonProperty("tools") List tools;
+
+ /**
+ * An object specifying the format that the model must output. Setting to { "type":
+ * "json_object" } enables JSON mode, which guarantees the message the model generates
+ * is valid JSON.
+ */
+ private @JsonProperty("response_format") ResponseFormat responseFormat;
+
+ /**
+ * Used to select the safety instruction inserted into the prompt. Defaults to
+ * CONTEXTUAL. When OFF is specified, the safety instruction will be omitted.
+ */
+ private @JsonProperty("safety_mode") CohereApi.SafetyMode safetyMode;
+
+ /**
+ * A list of up to 5 strings that the model will use to stop generation. If the model
+ * generates a string that matches any of the strings in the list, it will stop
+ * generating tokens and return the generated text up to that point not including the
+ * stop sequence.
+ */
+ private @JsonProperty("stop_sequences") List stopSequences;
+
+ /**
+ * If specified, the backend will make a best effort to sample tokens
+ * deterministically, such that repeated requests with the same seed and parameters
+ * should return the same result. However, determinism cannot be totally guaranteed.
+ */
+ private @JsonProperty("seed") Integer seed;
+
+ /**
+ * Defaults to false. When set to true, the log probabilities of the generated tokens
+ * will be included in the response.
+ */
+ private @JsonProperty("logprobs") Boolean logprobs;
+
+ /**
+ * Controls which (if any) function is called by the model. none means the model will
+ * not call a function and instead generates a message. auto means the model can pick
+ * between generating a message or calling a function. Specifying a particular
+ * function via {"type: "function", "function": {"name": "my_function"}} forces the
+ * model to call that function. none is the default when no functions are present.
+ * auto is the default if functions are present. Use the
+ * {@link CohereApi.ToolChoiceBuilder} to create a tool choice object.
+ */
+ private @JsonProperty("tool_choice") ToolChoice toolChoice;
+
+ private @JsonProperty("strict_tools") Boolean strictTools;
+
+ /**
+ * The level of detail for processing images. Can be "low", "high", or "auto".
+ * Defaults to "auto" if not specified. This controls the resolution at which the
+ * model views image.
+ */
+ @JsonIgnore
+ private DetailLevel imageDetail;
+
+ /**
+ * Collection of {@link ToolCallback}s to be used for tool calling in the chat
+ * completion requests.
+ */
+ @JsonIgnore
+ private List toolCallbacks = new ArrayList<>();
+
+ /**
+ * Collection of tool names to be resolved at runtime and used for tool calling in the
+ * chat completion requests.
+ */
+ @JsonIgnore
+ private Set toolNames = new HashSet<>();
+
+ /**
+ * Whether to enable the tool execution lifecycle internally in ChatModel.
+ */
+ @JsonIgnore
+ private Boolean internalToolExecutionEnabled;
+
+ @JsonIgnore
+ private Map toolContext = new HashMap<>();
+
+ public CohereApi.SafetyMode getSafetyMode() {
+ return this.safetyMode;
+ }
+
+ public void setSafetyMode(CohereApi.SafetyMode safetyMode) {
+ this.safetyMode = safetyMode;
+ }
+
+ public Integer getSeed() {
+ return this.seed;
+ }
+
+ public void setSeed(Integer seed) {
+ this.seed = seed;
+ }
+
+ public Boolean getLogprobs() {
+ return this.logprobs;
+ }
+
+ public void setLogprobs(Boolean logprobs) {
+ this.logprobs = logprobs;
+ }
+
+ public Boolean getStrictTools() {
+ return this.strictTools;
+ }
+
+ public void setStrictTools(Boolean strictTools) {
+ this.strictTools = strictTools;
+ }
+
+ public DetailLevel getImageDetail() {
+ return this.imageDetail;
+ }
+
+ public void setImageDetail(DetailLevel imageDetail) {
+ this.imageDetail = imageDetail;
+ }
+
+ public Double getP() {
+ return this.p;
+ }
+
+ public void setP(Double p) {
+ this.p = p;
+ }
+
+ @Override
+ public String getModel() {
+ return this.model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+ @Override
+ public Integer getMaxTokens() {
+ return this.maxTokens;
+ }
+
+ public void setMaxTokens(Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public ResponseFormat getResponseFormat() {
+ return this.responseFormat;
+ }
+
+ public void setResponseFormat(ResponseFormat responseFormat) {
+ this.responseFormat = responseFormat;
+ }
+
+ @Override
+ @JsonIgnore
+ public List getStopSequences() {
+ return getStop();
+ }
+
+ @JsonIgnore
+ public void setStopSequences(List stopSequences) {
+ setStop(stopSequences);
+ }
+
+ public List getStop() {
+ return this.stopSequences;
+ }
+
+ public void setStop(List stop) {
+ this.stopSequences = stop;
+ }
+
+ public List getTools() {
+ return this.tools;
+ }
+
+ public void setTools(List tools) {
+ this.tools = tools;
+ }
+
+ public ToolChoice getToolChoice() {
+ return this.toolChoice;
+ }
+
+ public void setToolChoice(ToolChoice toolChoice) {
+ this.toolChoice = toolChoice;
+ }
+
+ @Override
+ public Double getTemperature() {
+ return this.temperature;
+ }
+
+ public void setTemperature(Double temperature) {
+ this.temperature = temperature;
+ }
+
+ @Override
+ public Double getTopP() {
+ return getP();
+ }
+
+ public void setTopP(Double topP) {
+ setP(topP);
+ }
+
+ @Override
+ public Double getFrequencyPenalty() {
+ return this.frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(Double frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ @Override
+ public Double getPresencePenalty() {
+ return this.presencePenalty;
+ }
+
+ public void setPresencePenalty(Double presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ @Override
+ public CohereChatOptions copy() {
+ return fromOptions(this);
+ }
+
+ @Override
+ @JsonIgnore
+ public List getToolCallbacks() {
+ return this.toolCallbacks;
+ }
+
+ @Override
+ @JsonIgnore
+ public void setToolCallbacks(List toolCallbacks) {
+ Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
+ Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
+ this.toolCallbacks = toolCallbacks;
+ }
+
+ @Override
+ @JsonIgnore
+ public Set getToolNames() {
+ return this.toolNames;
+ }
+
+ @Override
+ @JsonIgnore
+ public void setToolNames(Set toolNames) {
+ Assert.notNull(toolNames, "toolNames cannot be null");
+ Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
+ toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
+ this.toolNames = toolNames;
+ }
+
+ @Override
+ @Nullable
+ @JsonIgnore
+ public Boolean getInternalToolExecutionEnabled() {
+ return this.internalToolExecutionEnabled;
+ }
+
+ @Override
+ @JsonIgnore
+ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
+ this.internalToolExecutionEnabled = internalToolExecutionEnabled;
+ }
+
+ @Override
+ @JsonIgnore
+ public Integer getTopK() {
+ return this.k;
+ }
+
+ public void setTopK(Integer k) {
+ this.k = k;
+ }
+
+ @Override
+ @JsonIgnore
+ public Map getToolContext() {
+ return this.toolContext;
+ }
+
+ @Override
+ @JsonIgnore
+ public void setToolContext(Map toolContext) {
+ this.toolContext = toolContext;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ CohereChatOptions that = (CohereChatOptions) o;
+ return Objects.equals(this.model, that.model) && Objects.equals(this.temperature, that.temperature)
+ && Objects.equals(this.p, that.p) && Objects.equals(this.maxTokens, that.maxTokens)
+ && Objects.equals(this.presencePenalty, that.presencePenalty)
+ && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) && Objects.equals(this.k, that.k)
+ && Objects.equals(this.tools, that.tools) && Objects.equals(this.responseFormat, that.responseFormat)
+ && Objects.equals(this.safetyMode, that.safetyMode)
+ && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.seed, that.seed)
+ && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.toolChoice, that.toolChoice)
+ && Objects.equals(this.strictTools, that.strictTools)
+ && Objects.equals(this.imageDetail, that.imageDetail)
+ && Objects.equals(this.toolCallbacks, that.toolCallbacks)
+ && Objects.equals(this.toolNames, that.toolNames)
+ && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
+ && Objects.equals(this.toolContext, that.toolContext);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(this.model, this.temperature, this.p, this.maxTokens, this.presencePenalty,
+ this.frequencyPenalty, this.k, this.tools, this.responseFormat, this.safetyMode, this.stopSequences,
+ this.seed, this.logprobs, this.toolChoice, this.strictTools, this.imageDetail, this.toolCallbacks,
+ this.toolNames, this.internalToolExecutionEnabled, this.toolContext);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static CohereChatOptions fromOptions(CohereChatOptions fromOptions) {
+ Builder builder = builder().model(fromOptions.getModel())
+ .temperature(fromOptions.getTemperature())
+ .maxTokens(fromOptions.getMaxTokens())
+ .topP(fromOptions.getTopP())
+ .frequencyPenalty(fromOptions.getFrequencyPenalty())
+ .presencePenalty(fromOptions.getPresencePenalty())
+ .topK(fromOptions.getTopK())
+ .responseFormat(fromOptions.getResponseFormat())
+ .safetyMode(fromOptions.getSafetyMode())
+ .seed(fromOptions.getSeed())
+ .logprobs(fromOptions.getLogprobs())
+ .toolChoice(fromOptions.getToolChoice())
+ .strictTools(fromOptions.getStrictTools())
+ .imageDetail(fromOptions.getImageDetail())
+ .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
+
+ // Create defensive copies of collections
+ if (fromOptions.getTools() != null) {
+ builder.tools(new ArrayList<>(fromOptions.getTools()));
+ }
+ if (fromOptions.getStopSequences() != null) {
+ builder.stop(new ArrayList<>(fromOptions.getStopSequences()));
+ }
+ if (fromOptions.getToolCallbacks() != null) {
+ builder.toolCallbacks(new ArrayList<>(fromOptions.getToolCallbacks()));
+ }
+ if (fromOptions.getToolNames() != null) {
+ builder.toolNames(new HashSet<>(fromOptions.getToolNames()));
+ }
+ if (fromOptions.getToolContext() != null) {
+ builder.toolContext(new HashMap<>(fromOptions.getToolContext()));
+ }
+
+ return builder.build();
+ }
+
+ public static CohereChatOptions fromOptions2(CohereChatOptions fromOptions) {
+ return builder().model(fromOptions.getModel())
+ .temperature(fromOptions.getTemperature())
+ .maxTokens(fromOptions.getMaxTokens())
+ .topP(fromOptions.getTopP())
+ .frequencyPenalty(fromOptions.getFrequencyPenalty())
+ .presencePenalty(fromOptions.getPresencePenalty())
+ .topK(fromOptions.getTopK())
+ .tools(null)
+ .responseFormat(fromOptions.getResponseFormat())
+ .safetyMode(fromOptions.getSafetyMode())
+ .stop(fromOptions.getStopSequences())
+ .seed(fromOptions.getSeed())
+ .logprobs(fromOptions.getLogprobs())
+ .toolChoice(null)
+ .strictTools(null)
+ .toolCallbacks()
+ .toolNames()
+ .internalToolExecutionEnabled(null)
+ .build();
+ }
+
+ public static class Builder {
+
+ private final CohereChatOptions options = new CohereChatOptions();
+
+ public CohereChatOptions build() {
+ return this.options;
+ }
+
+ public Builder model(String model) {
+ this.options.setModel(model);
+ return this;
+ }
+
+ public Builder model(CohereApi.ChatModel chatModel) {
+ this.options.setModel(chatModel.getName());
+ return this;
+ }
+
+ public Builder safetyMode(CohereApi.SafetyMode safetyMode) {
+ this.options.setSafetyMode(safetyMode);
+ return this;
+ }
+
+ public Builder logprobs(Boolean logprobs) {
+ this.options.setLogprobs(logprobs);
+ return this;
+ }
+
+ public Builder toolContext(Map toolContext) {
+ if (this.options.toolContext == null) {
+ this.options.toolContext = toolContext;
+ }
+ else {
+ this.options.toolContext.putAll(toolContext);
+ }
+ return this;
+ }
+
+ public Builder maxTokens(Integer maxTokens) {
+ this.options.setMaxTokens(maxTokens);
+ return this;
+ }
+
+ public Builder seed(Integer seed) {
+ this.options.setSeed(seed);
+ return this;
+ }
+
+ public Builder stop(List stop) {
+ this.options.setStop(stop);
+ return this;
+ }
+
+ public Builder frequencyPenalty(Double frequencyPenalty) {
+ this.options.frequencyPenalty = frequencyPenalty;
+ return this;
+ }
+
+ public Builder presencePenalty(Double presencePenalty) {
+ this.options.presencePenalty = presencePenalty;
+ return this;
+ }
+
+ public Builder temperature(Double temperature) {
+ this.options.setTemperature(temperature);
+ return this;
+ }
+
+ public Builder topP(Double topP) {
+ this.options.setTopP(topP);
+ return this;
+ }
+
+ public Builder topK(Integer k) {
+ this.options.setTopK(k);
+ return this;
+ }
+
+ public Builder responseFormat(ResponseFormat responseFormat) {
+ this.options.responseFormat = responseFormat;
+ return this;
+ }
+
+ public Builder tools(List tools) {
+ this.options.tools = tools;
+ return this;
+ }
+
+ public Builder strictTools(Boolean strictTools) {
+ this.options.setStrictTools(strictTools);
+ return this;
+ }
+
+ public Builder toolChoice(ToolChoice toolChoice) {
+ this.options.toolChoice = toolChoice;
+ return this;
+ }
+
+ public Builder toolCallbacks(List toolCallbacks) {
+ this.options.setToolCallbacks(toolCallbacks);
+ return this;
+ }
+
+ public Builder toolCallbacks(ToolCallback... toolCallbacks) {
+ Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
+ this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
+ return this;
+ }
+
+ public Builder toolNames(Set toolNames) {
+ Assert.notNull(toolNames, "toolNames cannot be null");
+ this.options.setToolNames(toolNames);
+ return this;
+ }
+
+ public Builder toolNames(String... toolNames) {
+ Assert.notNull(toolNames, "toolNames cannot be null");
+ this.options.toolNames.addAll(Set.of(toolNames));
+ return this;
+ }
+
+ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
+ this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
+ return this;
+ }
+
+ public Builder imageDetail(DetailLevel imageDetail) {
+ this.options.setImageDetail(imageDetail);
+ return this;
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java
new file mode 100644
index 00000000000..43ee99c300f
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/chat/CohereImageValidator.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.chat;
+
+import java.util.List;
+import java.util.Set;
+
+import org.springframework.ai.content.Media;
+
+/**
+ * Validator for Cohere API image constraints.
+ *
+ * @author Ricken Bazolo
+ */
+public final class CohereImageValidator {
+
+ private static final int MAX_IMAGES_PER_REQUEST = 20;
+
+ private static final long MAX_TOTAL_IMAGE_SIZE_BYTES = 20 * 1024 * 1024;
+
+ private static final Set SUPPORTED_IMAGE_FORMATS = Set.of("image/jpeg", "image/png", "image/webp",
+ "image/gif");
+
+ private CohereImageValidator() {
+ }
+
+ public static void validateImages(List mediaList) {
+ if (mediaList == null || mediaList.isEmpty()) {
+ return;
+ }
+
+ validateImageCount(mediaList);
+ validateImageFormats(mediaList);
+ validateTotalImageSize(mediaList);
+ }
+
+ private static void validateImageCount(List mediaList) {
+ if (mediaList.size() > MAX_IMAGES_PER_REQUEST) {
+ throw new IllegalArgumentException(
+ String.format("Cohere API supports maximum %d images per request, found: %d",
+ MAX_IMAGES_PER_REQUEST, mediaList.size()));
+ }
+ }
+
+ private static void validateImageFormats(List mediaList) {
+ for (Media media : mediaList) {
+ var mimeType = media.getMimeType().toString();
+ if (!SUPPORTED_IMAGE_FORMATS.contains(mimeType)) {
+ throw new IllegalArgumentException(String
+ .format("Unsupported image format: %s. Supported formats: JPEG, PNG, WebP, GIF", mimeType));
+ }
+ }
+ }
+
+ private static void validateTotalImageSize(List mediaList) {
+ long totalSize = 0;
+
+ for (Media media : mediaList) {
+ long mediaSize = calculateMediaSize(media);
+ totalSize += mediaSize;
+ }
+
+ if (totalSize > MAX_TOTAL_IMAGE_SIZE_BYTES) {
+ long totalSizeMB = totalSize / (1024 * 1024);
+ throw new IllegalArgumentException(String.format("Total image size exceeds 20MB limit: %dMB", totalSizeMB));
+ }
+ }
+
+ private static long calculateMediaSize(Media media) {
+ var data = media.getData();
+
+ if (data instanceof byte[] bytes) {
+ return bytes.length;
+ }
+
+ if (data instanceof String text) {
+ if (text.startsWith("data:")) {
+ var base64Data = text.substring(text.indexOf(",") + 1);
+ return (long) (base64Data.length() * 0.75);
+ }
+ return 0;
+ }
+
+ return 0;
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java
new file mode 100644
index 00000000000..d0d526a0a0a
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingModel.java
@@ -0,0 +1,249 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.embedding;
+
+import java.util.List;
+import java.util.Map;
+
+import io.micrometer.observation.ObservationRegistry;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.embedding.AbstractEmbeddingModel;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingOptions;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.util.Assert;
+
+/**
+ * Provides the Cohere Embedding Model.
+ *
+ * @author Ricken Bazolo
+ * @see AbstractEmbeddingModel
+ */
+public class CohereEmbeddingModel extends AbstractEmbeddingModel {
+
+ private static final Logger logger = LoggerFactory.getLogger(CohereEmbeddingModel.class);
+
+ /**
+ * Known embedding dimensions for Cohere models. Maps model names to their respective
+ * embedding vector dimensions. This allows the dimensions() method to return the
+ * correct value without making an API call.
+ */
+ private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of(
+ CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue(), 1024,
+ CohereApi.EmbeddingModel.EMBED_ENGLISH_V3.getValue(), 1024,
+ CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue(), 384,
+ CohereApi.EmbeddingModel.EMBED_ENGLISH_LIGHT_V3.getValue(), 384,
+ CohereApi.EmbeddingModel.EMBED_V4.getValue(), 1536);
+
+ private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
+
+ private final CohereEmbeddingOptions defaultOptions;
+
+ private final MetadataMode metadataMode;
+
+ private final CohereApi cohereApi;
+
+ private final RetryTemplate retryTemplate;
+
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
+
+ public CohereEmbeddingModel(CohereApi cohereApi, MetadataMode metadataMode, CohereEmbeddingOptions options,
+ RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
+ Assert.notNull(cohereApi, "cohereApi must not be null");
+ Assert.notNull(metadataMode, "metadataMode must not be null");
+ Assert.notNull(options, "options must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
+ Assert.notNull(observationRegistry, "observationRegistry must not be null");
+
+ this.cohereApi = cohereApi;
+ this.metadataMode = metadataMode;
+ this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
+ this.observationRegistry = observationRegistry;
+ }
+
+ @Override
+ public EmbeddingResponse call(EmbeddingRequest request) {
+
+ var apiRequest = createRequest(request);
+
+ var observationContext = EmbeddingModelObservationContext.builder()
+ .embeddingRequest(request)
+ .provider(CohereApi.PROVIDER_NAME)
+ .build();
+
+ return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+ var apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate,
+ () -> this.cohereApi.embeddings(apiRequest).getBody());
+
+ if (apiEmbeddingResponse == null) {
+ logger.warn("No embeddings returned for request: {}", request);
+ return new EmbeddingResponse(List.of());
+ }
+
+ var metadata = generateResponseMetadata(apiEmbeddingResponse.responseType());
+
+ // Extract float embeddings from response
+ List floatEmbeddings = apiEmbeddingResponse.getFloatEmbeddings();
+
+ // Map to Spring AI Embedding objects with proper indexing
+ List embeddings = new java.util.ArrayList<>();
+ for (int i = 0; i < floatEmbeddings.size(); i++) {
+ embeddings.add(new Embedding(floatEmbeddings.get(i), i));
+ }
+
+ var embeddingResponse = new EmbeddingResponse(embeddings, metadata);
+
+ observationContext.setResponse(embeddingResponse);
+
+ return embeddingResponse;
+ });
+ }
+
+ @Override
+ public float[] embed(Document document) {
+ Assert.notNull(document, "Document must not be null");
+ return this.embed(document.getFormattedContent(this.metadataMode));
+ }
+
+ private EmbeddingResponseMetadata generateResponseMetadata(String embeddingType) {
+ return new EmbeddingResponseMetadata(embeddingType, null);
+ }
+
+ /**
+ * Use the provided convention for reporting observation data.
+ * @param observationConvention The provided convention
+ */
+ public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "observationConvention cannot be null");
+ this.observationConvention = observationConvention;
+ }
+
+ private CohereApi.EmbeddingRequest createRequest(EmbeddingRequest request) {
+ CohereEmbeddingOptions options = mergeOptions(request.getOptions(), this.defaultOptions);
+
+ return CohereApi.EmbeddingRequest.builder()
+ .model(options.getModel())
+ .inputType(options.getInputType())
+ .embeddingTypes(options.getEmbeddingTypes())
+ .texts(request.getInstructions())
+ .truncate(options.getTruncate())
+ .build();
+ }
+
+ private CohereEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions,
+ CohereEmbeddingOptions defaultOptions) {
+ CohereEmbeddingOptions options = (requestOptions != null)
+ ? ModelOptionsUtils.merge(requestOptions, defaultOptions, CohereEmbeddingOptions.class)
+ : defaultOptions;
+
+ if (options == null) {
+ throw new IllegalArgumentException("Embedding options must not be null");
+ }
+
+ return options;
+ }
+
+ private CohereEmbeddingOptions buildRequestOptions(EmbeddingRequest request) {
+ return mergeOptions(request.getOptions(), this.defaultOptions);
+ }
+
+ @Override
+ public int dimensions() {
+ String model = this.defaultOptions.getModel();
+ if (model == null) {
+ return KNOWN_EMBEDDING_DIMENSIONS.get(CohereApi.EmbeddingModel.EMBED_V4.getValue());
+ }
+ return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(model, 1024);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static final class Builder {
+
+ private CohereApi cohereApi;
+
+ private MetadataMode metadataMode = MetadataMode.EMBED;
+
+ private CohereEmbeddingOptions options = CohereEmbeddingOptions.builder()
+ .model(CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue())
+ .build();
+
+ private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ public Builder cohereApi(CohereApi cohereApi) {
+ this.cohereApi = cohereApi;
+ return this;
+ }
+
+ public Builder metadataMode(MetadataMode metadataMode) {
+ this.metadataMode = metadataMode;
+ return this;
+ }
+
+ public Builder options(CohereEmbeddingOptions options) {
+ this.options = options;
+ return this;
+ }
+
+ public Builder retryTemplate(RetryTemplate retryTemplate) {
+ this.retryTemplate = retryTemplate;
+ return this;
+ }
+
+ public Builder observationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public CohereEmbeddingModel build() {
+ return new CohereEmbeddingModel(this.cohereApi, this.metadataMode, this.options, this.retryTemplate,
+ this.observationRegistry);
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java
new file mode 100644
index 00000000000..6b801ea26bb
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingOptions.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.embedding;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.InputType;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.Truncate;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingType;
+import org.springframework.ai.embedding.EmbeddingOptions;
+
+/**
+ * Options for the Cohere Embedding API.
+ *
+ * @author Ricken Bazolo
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public final class CohereEmbeddingOptions implements EmbeddingOptions {
+
+ /**
+ * ID of the model to use.
+ */
+ @JsonProperty("model")
+ private String model;
+
+ /**
+ * The type of input (search_document, search_query, classification, clustering).
+ */
+ @JsonProperty("input_type")
+ private InputType inputType;
+
+ /**
+ * The types of embeddings to return (float, int8, uint8, binary, ubinary).
+ */
+ @JsonProperty("embedding_types")
+ private List embeddingTypes = new ArrayList<>();
+
+ /**
+ * How to handle inputs longer than the maximum token length (NONE, START, END).
+ */
+ @JsonProperty("truncate")
+ private Truncate truncate;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static CohereEmbeddingOptions fromOptions(CohereEmbeddingOptions fromOptions) {
+ return builder().model(fromOptions.getModel())
+ .inputType(fromOptions.getInputType())
+ .embeddingTypes(
+ fromOptions.getEmbeddingTypes() != null ? new ArrayList<>(fromOptions.getEmbeddingTypes()) : null)
+ .truncate(fromOptions.getTruncate())
+ .build();
+ }
+
+ private CohereEmbeddingOptions() {
+ this.embeddingTypes.add(EmbeddingType.FLOAT);
+ this.inputType = InputType.CLASSIFICATION;
+ }
+
+ @Override
+ public String getModel() {
+ return this.model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+ public InputType getInputType() {
+ return this.inputType;
+ }
+
+ public void setInputType(InputType inputType) {
+ this.inputType = inputType;
+ }
+
+ public List getEmbeddingTypes() {
+ return this.embeddingTypes;
+ }
+
+ public void setEmbeddingTypes(List embeddingTypes) {
+ this.embeddingTypes = embeddingTypes;
+ }
+
+ public Truncate getTruncate() {
+ return this.truncate;
+ }
+
+ public void setTruncate(Truncate truncate) {
+ this.truncate = truncate;
+ }
+
+ @Override
+ public Integer getDimensions() {
+ // Cohere embeddings have fixed dimensions based on model
+ // embed-multilingual-v3 and embed-english-v3: 1024
+ // This should be handled by the model implementation
+ return null;
+ }
+
+ public CohereEmbeddingOptions copy() {
+ return fromOptions(this);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(this.model, this.inputType, this.embeddingTypes, this.truncate);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ CohereEmbeddingOptions that = (CohereEmbeddingOptions) o;
+
+ return Objects.equals(this.model, that.model) && Objects.equals(this.inputType, that.inputType)
+ && Objects.equals(this.embeddingTypes, that.embeddingTypes)
+ && Objects.equals(this.truncate, that.truncate);
+ }
+
+ public static final class Builder {
+
+ private CohereEmbeddingOptions options;
+
+ public Builder() {
+ this.options = new CohereEmbeddingOptions();
+ }
+
+ public Builder(CohereEmbeddingOptions options) {
+ this.options = options;
+ }
+
+ public Builder model(String model) {
+ this.options.model = model;
+ return this;
+ }
+
+ public Builder model(CohereApi.EmbeddingModel model) {
+ this.options.model = model.getValue();
+ return this;
+ }
+
+ public Builder inputType(InputType inputType) {
+ this.options.inputType = inputType;
+ return this;
+ }
+
+ public Builder embeddingTypes(List embeddingTypes) {
+ this.options.embeddingTypes = embeddingTypes;
+ return this;
+ }
+
+ public Builder truncate(Truncate truncate) {
+ this.options.truncate = truncate;
+ return this;
+ }
+
+ public CohereEmbeddingOptions build() {
+ return this.options;
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java
new file mode 100644
index 00000000000..997fb4bd89e
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereEmbeddingUtils.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.embedding;
+
+import java.util.Base64;
+import java.util.List;
+
+import org.springframework.ai.content.Media;
+import org.springframework.util.Assert;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+
+/**
+ * Utility class for Cohere embedding operations.
+ *
+ * @author Ricken Bazolo
+ */
+public final class CohereEmbeddingUtils {
+
+ private static final List SUPPORTED_IMAGE_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
+ MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/webp"), MimeTypeUtils.IMAGE_GIF);
+
+ private static final long MAX_IMAGE_SIZE_BYTES = 5 * 1024 * 1024;
+
+ private CohereEmbeddingUtils() {
+ }
+
+ public static String mediaToDataUri(Media media) {
+ Assert.notNull(media, "Media cannot be null");
+ validateImageMedia(media);
+
+ byte[] imageData = getImageBytes(media);
+ validateImageSize(imageData);
+
+ String base64Data = Base64.getEncoder().encodeToString(imageData);
+ String mimeType = media.getMimeType().toString();
+
+ return String.format("data:%s;base64,%s", mimeType, base64Data);
+ }
+
+ private static void validateImageMedia(Media media) {
+ MimeType mimeType = media.getMimeType();
+ boolean isSupported = SUPPORTED_IMAGE_TYPES.stream()
+ .anyMatch(supported -> mimeType.isCompatibleWith(supported));
+
+ if (!isSupported) {
+ throw new IllegalArgumentException("Unsupported image MIME type: " + mimeType
+ + ". Supported types: image/jpeg, image/png, image/webp, image/gif");
+ }
+ }
+
+ private static byte[] getImageBytes(Media media) {
+ Object data = media.getData();
+
+ if (data instanceof byte[] bytes) {
+ return bytes;
+ }
+ else if (data instanceof String base64String) {
+ return Base64.getDecoder().decode(base64String);
+ }
+ else {
+ throw new IllegalArgumentException("Media data must be byte[] or base64 String");
+ }
+ }
+
+ private static void validateImageSize(byte[] imageData) {
+ if (imageData.length > MAX_IMAGE_SIZE_BYTES) {
+ throw new IllegalArgumentException(
+ String.format("Image size (%d bytes) exceeds maximum allowed size (%d bytes)", imageData.length,
+ MAX_IMAGE_SIZE_BYTES));
+ }
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java
new file mode 100644
index 00000000000..15ea8b5b17e
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingModel.java
@@ -0,0 +1,208 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.embedding;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import io.micrometer.observation.ObservationRegistry;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.DocumentEmbeddingModel;
+import org.springframework.ai.embedding.DocumentEmbeddingRequest;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+/**
+ * Implementation of the Cohere Multimodal Embedding Model.
+ *
+ * @author Ricken Bazolo
+ */
+public class CohereMultimodalEmbeddingModel implements DocumentEmbeddingModel {
+
+ private static final Logger logger = LoggerFactory.getLogger(CohereMultimodalEmbeddingModel.class);
+
+ private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of(
+ CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_V3.getValue(), 1024,
+ CohereApi.EmbeddingModel.EMBED_ENGLISH_V3.getValue(), 1024,
+ CohereApi.EmbeddingModel.EMBED_MULTILINGUAL_LIGHT_V3.getValue(), 384,
+ CohereApi.EmbeddingModel.EMBED_ENGLISH_LIGHT_V3.getValue(), 384,
+ CohereApi.EmbeddingModel.EMBED_V4.getValue(), 1536);
+
+ private final CohereMultimodalEmbeddingOptions defaultOptions;
+
+ private final CohereApi cohereApi;
+
+ private final RetryTemplate retryTemplate;
+
+ private final ObservationRegistry observationRegistry;
+
+ public CohereMultimodalEmbeddingModel(CohereApi cohereApi, CohereMultimodalEmbeddingOptions options,
+ RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
+ Assert.notNull(cohereApi, "cohereApi must not be null");
+ Assert.notNull(options, "options must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
+ Assert.notNull(observationRegistry, "observationRegistry must not be null");
+
+ this.cohereApi = cohereApi;
+ this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
+ this.observationRegistry = observationRegistry;
+ }
+
+ @Override
+ public EmbeddingResponse call(DocumentEmbeddingRequest request) {
+ CohereMultimodalEmbeddingOptions mergedOptions = mergeOptions(request.getOptions(), this.defaultOptions);
+
+ List allEmbeddings = new ArrayList<>();
+ EmbeddingResponseMetadata lastMetadata = null;
+
+ for (Document document : request.getInstructions()) {
+ CohereApi.EmbeddingRequest> apiRequest;
+
+ if (document.getMedia() != null) {
+ apiRequest = createImageRequest(document, mergedOptions);
+ }
+ else if (StringUtils.hasText(document.getText())) {
+ apiRequest = createTextRequest(document, mergedOptions);
+ }
+ else {
+ logger.warn("Document {} has no text or media content", document.getId());
+ continue;
+ }
+
+ var apiResponse = RetryUtils.execute(this.retryTemplate,
+ () -> this.cohereApi.embeddings(apiRequest).getBody());
+
+ if (apiResponse != null) {
+ List floatEmbeddings = apiResponse.getFloatEmbeddings();
+ for (int i = 0; i < floatEmbeddings.size(); i++) {
+ allEmbeddings.add(new Embedding(floatEmbeddings.get(i), allEmbeddings.size()));
+ }
+ lastMetadata = generateResponseMetadata(apiResponse.responseType());
+ }
+ }
+
+ return new EmbeddingResponse(allEmbeddings, lastMetadata);
+ }
+
+ @Override
+ public int dimensions() {
+ String model = this.defaultOptions.getModel();
+ if (model == null) {
+ return KNOWN_EMBEDDING_DIMENSIONS.get(CohereApi.EmbeddingModel.EMBED_V4.getValue());
+ }
+ return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(model, 1024);
+ }
+
+ private CohereApi.EmbeddingRequest createTextRequest(Document document,
+ CohereMultimodalEmbeddingOptions options) {
+ return CohereApi.EmbeddingRequest.builder()
+ .model(options.getModel())
+ .inputType(CohereApi.EmbeddingRequest.InputType.CLASSIFICATION)
+ .embeddingTypes(options.getEmbeddingTypes())
+ .texts(List.of(document.getText()))
+ .truncate(options.getTruncate())
+ .build();
+ }
+
+ private CohereApi.EmbeddingRequest createImageRequest(Document document,
+ CohereMultimodalEmbeddingOptions options) {
+
+ String dataUri = CohereEmbeddingUtils.mediaToDataUri(document.getMedia());
+
+ return CohereApi.EmbeddingRequest.builder()
+ .model(options.getModel())
+ .inputType(CohereApi.EmbeddingRequest.InputType.IMAGE)
+ .embeddingTypes(options.getEmbeddingTypes())
+ .images(List.of(dataUri))
+ .truncate(options.getTruncate())
+ .build();
+ }
+
+ private CohereMultimodalEmbeddingOptions mergeOptions(
+ org.springframework.ai.embedding.EmbeddingOptions requestOptions,
+ CohereMultimodalEmbeddingOptions defaultOptions) {
+ CohereMultimodalEmbeddingOptions options = (requestOptions != null)
+ ? ModelOptionsUtils.merge(requestOptions, defaultOptions, CohereMultimodalEmbeddingOptions.class)
+ : defaultOptions;
+
+ if (options == null) {
+ throw new IllegalArgumentException("Embedding options must not be null");
+ }
+
+ return options;
+ }
+
+ private EmbeddingResponseMetadata generateResponseMetadata(String embeddingType) {
+ return new EmbeddingResponseMetadata(embeddingType, null);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static final class Builder {
+
+ private CohereApi cohereApi;
+
+ private CohereMultimodalEmbeddingOptions options = CohereMultimodalEmbeddingOptions.builder()
+ .model(CohereApi.EmbeddingModel.EMBED_V4.getValue())
+ .build();
+
+ private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ public Builder cohereApi(CohereApi cohereApi) {
+ this.cohereApi = cohereApi;
+ return this;
+ }
+
+ public Builder options(CohereMultimodalEmbeddingOptions options) {
+ this.options = options;
+ return this;
+ }
+
+ public Builder retryTemplate(RetryTemplate retryTemplate) {
+ this.retryTemplate = retryTemplate;
+ return this;
+ }
+
+ public Builder observationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public CohereMultimodalEmbeddingModel build() {
+ return new CohereMultimodalEmbeddingModel(this.cohereApi, this.options, this.retryTemplate,
+ this.observationRegistry);
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java
new file mode 100644
index 00000000000..40ad1d3ed41
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/java/org/springframework/ai/cohere/embedding/CohereMultimodalEmbeddingOptions.java
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.embedding;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.InputType;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest.Truncate;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingType;
+import org.springframework.ai.embedding.EmbeddingOptions;
+
+/**
+ * Options for the Cohere Multimodal Embedding API.
+ *
+ * @author Ricken Bazolo
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public final class CohereMultimodalEmbeddingOptions implements EmbeddingOptions {
+
+ @JsonProperty("model")
+ private String model;
+
+ @JsonProperty("input_type")
+ private InputType inputType;
+
+ @JsonProperty("embedding_types")
+ private List embeddingTypes = new ArrayList<>();
+
+ @JsonProperty("truncate")
+ private Truncate truncate;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static CohereMultimodalEmbeddingOptions fromOptions(CohereMultimodalEmbeddingOptions fromOptions) {
+ return builder().model(fromOptions.getModel())
+ .inputType(fromOptions.getInputType())
+ .embeddingTypes(
+ fromOptions.getEmbeddingTypes() != null ? new ArrayList<>(fromOptions.getEmbeddingTypes()) : null)
+ .truncate(fromOptions.getTruncate())
+ .build();
+ }
+
+ private CohereMultimodalEmbeddingOptions() {
+ this.embeddingTypes.add(EmbeddingType.FLOAT);
+ this.inputType = InputType.CLASSIFICATION;
+ this.model = CohereApi.EmbeddingModel.EMBED_V4.getValue();
+ }
+
+ @Override
+ public String getModel() {
+ return this.model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+ public InputType getInputType() {
+ return this.inputType;
+ }
+
+ public void setInputType(InputType inputType) {
+ this.inputType = inputType;
+ }
+
+ public List getEmbeddingTypes() {
+ return this.embeddingTypes;
+ }
+
+ public void setEmbeddingTypes(List embeddingTypes) {
+ this.embeddingTypes = embeddingTypes;
+ }
+
+ public Truncate getTruncate() {
+ return this.truncate;
+ }
+
+ public void setTruncate(Truncate truncate) {
+ this.truncate = truncate;
+ }
+
+ @Override
+ public Integer getDimensions() {
+ return null;
+ }
+
+ public CohereMultimodalEmbeddingOptions copy() {
+ return fromOptions(this);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(this.model, this.inputType, this.embeddingTypes, this.truncate);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ CohereMultimodalEmbeddingOptions that = (CohereMultimodalEmbeddingOptions) o;
+
+ return Objects.equals(this.model, that.model) && Objects.equals(this.inputType, that.inputType)
+ && Objects.equals(this.embeddingTypes, that.embeddingTypes)
+ && Objects.equals(this.truncate, that.truncate);
+ }
+
+ public static final class Builder {
+
+ private CohereMultimodalEmbeddingOptions options;
+
+ public Builder() {
+ this.options = new CohereMultimodalEmbeddingOptions();
+ }
+
+ public Builder(CohereMultimodalEmbeddingOptions options) {
+ this.options = options;
+ }
+
+ public Builder model(String model) {
+ this.options.model = model;
+ return this;
+ }
+
+ public Builder model(CohereApi.EmbeddingModel model) {
+ this.options.model = model.getValue();
+ return this;
+ }
+
+ public Builder inputType(InputType inputType) {
+ this.options.inputType = inputType;
+ return this;
+ }
+
+ public Builder embeddingTypes(List embeddingTypes) {
+ this.options.embeddingTypes = embeddingTypes;
+ return this;
+ }
+
+ public Builder truncate(Truncate truncate) {
+ this.options.truncate = truncate;
+ return this;
+ }
+
+ public CohereMultimodalEmbeddingOptions build() {
+ return this.options;
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories
new file mode 100644
index 00000000000..e6ba5bc93af
--- /dev/null
+++ b/models/spring-ai-cohere/src/main/resources/META-INF/spring/aot.factories
@@ -0,0 +1,2 @@
+org.springframework.aot.hint.RuntimeHintsRegistrar=\
+ org.springframework.ai.cohere.aot.CohereRuntimeHints
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java
new file mode 100644
index 00000000000..89d4cdb2b84
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereRetryTests.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere;
+
+import java.util.List;
+import java.util.Optional;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletion;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionFinishReason;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingRequest;
+import org.springframework.ai.cohere.api.CohereApi.EmbeddingResponse;
+import org.springframework.ai.cohere.api.CohereApi.Usage;
+import org.springframework.ai.cohere.chat.CohereChatModel;
+import org.springframework.ai.cohere.chat.CohereChatOptions;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingModel;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.core.retry.RetryListener;
+import org.springframework.core.retry.RetryPolicy;
+import org.springframework.core.retry.RetryTemplate;
+import org.springframework.core.retry.Retryable;
+import org.springframework.http.ResponseEntity;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.BDDMockito.given;
+
+/**
+ * @author Ricken Bazolo
+ */
+@SuppressWarnings("unchecked")
+@ExtendWith(MockitoExtension.class)
+public class CohereRetryTests {
+
+ private TestRetryListener retryListener;
+
+ private RetryTemplate retryTemplate;
+
+ private @Mock CohereApi cohereApi;
+
+ private CohereChatModel chatModel;
+
+ private CohereEmbeddingModel embeddingModel;
+
+ @BeforeEach
+ public void beforeEach() {
+ this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE;
+ this.retryListener = new TestRetryListener();
+ this.retryTemplate.setRetryListener(this.retryListener);
+
+ this.chatModel = CohereChatModel.builder()
+ .cohereApi(this.cohereApi)
+ .defaultOptions(CohereChatOptions.builder()
+ .temperature(0.7)
+ .topP(1.0)
+ .model(CohereApi.ChatModel.COMMAND_A_R7B.getValue())
+ .build())
+ .retryTemplate(this.retryTemplate)
+ .build();
+ this.embeddingModel = CohereEmbeddingModel.builder()
+ .cohereApi(this.cohereApi)
+ .retryTemplate(this.retryTemplate)
+ .build();
+ }
+
+ @Test
+ public void cohereChatTransientError() {
+ var message = new ChatCompletionMessage.Provider(
+ List.of(new ChatCompletionMessage.MessageContent("text", "Response", null)), Role.ASSISTANT, null, null,
+ null);
+
+ ChatCompletion expectedChatCompletion = new ChatCompletion("id", ChatCompletionFinishReason.COMPLETE, message,
+ null, new Usage(null, new Usage.Tokens(10, 20), 10));
+
+ given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .willThrow(new TransientAiException("Transient Error 1"))
+ .willThrow(new TransientAiException("Transient Error 2"))
+ .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
+
+ var result = this.chatModel.call(new Prompt("text"));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput().getText()).isEqualTo("Response");
+ assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
+ assertThat(this.retryListener.retryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void cohereChatNonTransientError() {
+ given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .willThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text")));
+ }
+
+ @Test
+ public void cohereEmbeddingTransientError() {
+ List> embeddingsList = List.of(List.of(9.9, 8.8), List.of(7.7, 6.6));
+
+ EmbeddingResponse expectedEmbeddings = new EmbeddingResponse("id", embeddingsList, List.of("text1", "text2"),
+ "embeddings_floats");
+
+ given(this.cohereApi.embeddings(isA(EmbeddingRequest.class)))
+ .willThrow(new TransientAiException("Transient Error 1"))
+ .willThrow(new TransientAiException("Transient Error 2"))
+ .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
+
+ var result = this.embeddingModel
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
+
+ assertThat(result).isNotNull();
+ assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
+ assertThat(this.retryListener.retryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void cohereEmbeddingNonTransientError() {
+ given(this.cohereApi.embeddings(isA(EmbeddingRequest.class)))
+ .willThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> this.embeddingModel
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
+ }
+
+ @Test
+ public void cohereChatMixedTransientAndNonTransientErrors() {
+ given(this.cohereApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .willThrow(new TransientAiException("Transient Error"))
+ .willThrow(new RuntimeException("Non Transient Error"));
+
+ // Should fail immediately on non-transient error, no further retries
+ assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text")));
+
+ // Should have 1 retry attempt before hitting non-transient error
+ assertThat(this.retryListener.retryCount).isEqualTo(1);
+ }
+
+ private static class TestRetryListener implements RetryListener {
+
+ int retryCount = 0;
+
+ int onSuccessRetryCount = 0;
+
+ @Override
+ public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable> retryable, final Object result) {
+ // Count successful retries - we increment when we succeed after a failure
+ this.onSuccessRetryCount++;
+ }
+
+ @Override
+ public void beforeRetry(RetryPolicy retryPolicy, Retryable> retryable) {
+ this.retryCount++;
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java
new file mode 100644
index 00000000000..4c03a23efb9
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/CohereTestConfiguration.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.chat.CohereChatModel;
+import org.springframework.ai.cohere.chat.CohereChatOptions;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingModel;
+import org.springframework.ai.cohere.embedding.CohereMultimodalEmbeddingModel;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.context.annotation.Bean;
+import org.springframework.util.StringUtils;
+
+/**
+ * @author Ricken Bazolo
+ */
+@SpringBootConfiguration
+public class CohereTestConfiguration {
+
+ private static String retrieveApiKey() {
+ var apiKey = System.getenv("COHERE_API_KEY");
+ if (!StringUtils.hasText(apiKey)) {
+ throw new IllegalArgumentException(
+ "Missing COHERE_API_KEY environment variable. Please set it to your Cohere API key.");
+ }
+ return apiKey;
+ }
+
+ @Bean
+ public CohereApi cohereApi() {
+ return CohereApi.builder().apiKey(retrieveApiKey()).build();
+ }
+
+ @Bean
+ public CohereChatModel cohereChatModel(CohereApi api) {
+ return CohereChatModel.builder()
+ .cohereApi(api)
+ .defaultOptions(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A.getValue()).build())
+ .build();
+ }
+
+ @Bean
+ public CohereEmbeddingModel cohereEmbeddingModel(CohereApi api) {
+ return CohereEmbeddingModel.builder().cohereApi(api).build();
+ }
+
+ @Bean
+ public CohereMultimodalEmbeddingModel cohereMultimodalEmbeddingModel(CohereApi api) {
+ return CohereMultimodalEmbeddingModel.builder().cohereApi(api).build();
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java
new file mode 100644
index 00000000000..065edf04eeb
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/aot/CohereRuntimeHintsTests.java
@@ -0,0 +1,245 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.aot;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.chat.CohereChatOptions;
+import org.springframework.ai.cohere.embedding.CohereEmbeddingOptions;
+import org.springframework.aot.hint.RuntimeHints;
+import org.springframework.aot.hint.TypeReference;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
+import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
+
+class CohereRuntimeHintsTests {
+
+ @Test
+ void registerHints() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere");
+
+ Set registeredTypes = new HashSet<>();
+ runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
+
+ for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
+ assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
+ }
+
+ // Check a few more specific ones
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletion.class))).isTrue();
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletionChunk.class))).isTrue();
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.LogProbs.class))).isTrue();
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatCompletionFinishReason.class))).isTrue();
+ assertThat(registeredTypes.contains(TypeReference.of(CohereChatOptions.class))).isTrue();
+ assertThat(registeredTypes.contains(TypeReference.of(CohereEmbeddingOptions.class))).isTrue();
+ }
+
+ @Test
+ void registerHintsWithNullClassLoader() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+
+ // Should not throw exception with null classLoader
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ // Verify hints were registered
+ assertThat(runtimeHints.reflection().typeHints().count()).isGreaterThan(0);
+ }
+
+ @Test
+ void registerHintsWithValidClassLoader() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+
+ cohereRuntimeHints.registerHints(runtimeHints, classLoader);
+
+ // Verify hints were registered
+ assertThat(runtimeHints.reflection().typeHints().count()).isGreaterThan(0);
+ }
+
+ @Test
+ void registerHintsIsIdempotent() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+
+ // Register hints twice
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+ long firstCount = runtimeHints.reflection().typeHints().count();
+
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+ long secondCount = runtimeHints.reflection().typeHints().count();
+
+ // Should have same number of hints
+ assertThat(firstCount).isEqualTo(secondCount);
+ }
+
+ @Test
+ void verifyExpectedTypesAreRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ Set registeredTypes = new HashSet<>();
+ runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
+
+ // Verify some expected types are registered (adjust class names as needed)
+ assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("Cohere"))).isTrue();
+ assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("ChatCompletion"))).isTrue();
+ }
+
+ @Test
+ void verifyPackageScanningWorks() {
+ Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.cohere");
+
+ // Verify package scanning found classes
+ assertThat(jsonAnnotatedClasses.size()).isGreaterThan(0);
+ }
+
+ @Test
+ void verifyAllCriticalApiClassesAreRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ Set registeredTypes = new HashSet<>();
+ runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
+
+ // Ensure critical API classes are registered for GraalVM native image reflection
+ String[] criticalClasses = { "CohereApi$ChatCompletionRequest", "CohereApi$ChatCompletionMessage",
+ "CohereApi$EmbeddingRequest", "CohereApi$EmbeddingModel", "CohereApi$Usage" };
+
+ for (String className : criticalClasses) {
+ assertThat(registeredTypes.stream()
+ .anyMatch(tr -> tr.getName().contains(className.replace("$", "."))
+ || tr.getName().contains(className.replace("$", "$"))))
+ .as("Critical class %s should be registered", className)
+ .isTrue();
+ }
+ }
+
+ @Test
+ void verifyEnumTypesAreRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ Set registeredTypes = new HashSet<>();
+ runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
+
+ // Enums are critical for JSON deserialization in native images
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.ChatModel.class)))
+ .as("ChatModel enum should be registered")
+ .isTrue();
+
+ assertThat(registeredTypes.contains(TypeReference.of(CohereApi.EmbeddingModel.class)))
+ .as("EmbeddingModel enum should be registered")
+ .isTrue();
+ }
+
+ @Test
+ void verifyReflectionHintsIncludeConstructors() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ // Verify that reflection hints include constructor access
+ boolean hasConstructorHints = runtimeHints.reflection()
+ .typeHints()
+ .anyMatch(typeHint -> typeHint.constructors().findAny().isPresent() || typeHint.getMemberCategories()
+ .contains(org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS));
+
+ assertThat(hasConstructorHints).as("Should register constructor hints for JSON deserialization").isTrue();
+ }
+
+ @Test
+ void verifyNoExceptionThrownWithEmptyRuntimeHints() {
+ RuntimeHints emptyRuntimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+
+ // Should not throw any exception even with empty runtime hints
+ assertThatCode(() -> cohereRuntimeHints.registerHints(emptyRuntimeHints, null)).doesNotThrowAnyException();
+
+ assertThat(emptyRuntimeHints.reflection().typeHints().count()).isGreaterThan(0);
+ }
+
+ @Test
+ void verifyProxyHintsAreNotRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ // MistralAi should only register reflection hints, not proxy hints
+ assertThat(runtimeHints.proxies().jdkProxyHints().count()).isEqualTo(0);
+ }
+
+ @Test
+ void verifySerializationHintsAreNotRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ // MistralAi should only register reflection hints, not serialization hints
+ assertThat(runtimeHints.serialization().javaSerializationHints().count()).isEqualTo(0);
+ }
+
+ @Test
+ void verifyResponseTypesAreRegistered() {
+ RuntimeHints runtimeHints = new RuntimeHints();
+ CohereRuntimeHints cohereRuntimeHints = new CohereRuntimeHints();
+ cohereRuntimeHints.registerHints(runtimeHints, null);
+
+ Set registeredTypes = new HashSet<>();
+ runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
+
+ // Verify response wrapper types are registered
+ assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("EmbeddingResponse")))
+ .as("EmbeddingResponse type should be registered")
+ .isTrue();
+
+ assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("ChatCompletion")))
+ .as("ChatCompletion response type should be registered")
+ .isTrue();
+ }
+
+ @Test
+ void verifyMultipleInstancesRegisterSameHints() {
+ RuntimeHints runtimeHints1 = new RuntimeHints();
+ RuntimeHints runtimeHints2 = new RuntimeHints();
+
+ CohereRuntimeHints hints1 = new CohereRuntimeHints();
+ CohereRuntimeHints hints2 = new CohereRuntimeHints();
+
+ hints1.registerHints(runtimeHints1, null);
+ hints2.registerHints(runtimeHints2, null);
+
+ long count1 = runtimeHints1.reflection().typeHints().count();
+ long count2 = runtimeHints2.reflection().typeHints().count();
+
+ assertThat(count1).isEqualTo(count2);
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java
new file mode 100644
index 00000000000..656f268ad36
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/CohereApiIT.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.cohere.CohereTestConfiguration;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletion;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest;
+import org.springframework.ai.cohere.testutils.AbstractIT;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.http.ResponseEntity;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Ricken Bazolo
+ */
+@SpringBootTest(classes = CohereTestConfiguration.class)
+@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+")
+class CohereApiIT extends AbstractIT {
+
+ @Test
+ void chatCompletionEntity() {
+ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
+ ResponseEntity response = this.cohereApi.chatCompletionEntity(new ChatCompletionRequest(
+ List.of(chatCompletionMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, false));
+
+ assertThat(response).isNotNull();
+ assertThat(response.getBody()).isNotNull();
+ }
+
+ @Test
+ void chatCompletionEntityWithSystemMessage() {
+ ChatCompletionMessage userMessage = new ChatCompletionMessage(
+ "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did?", Role.USER);
+ ChatCompletionMessage systemMessage = new ChatCompletionMessage("""
+ You are an AI assistant that helps people find information.
+ Your name is Bob.
+ You should reply to the user's request with your name and also in the style of a pirate.
+ """, Role.SYSTEM);
+
+ ResponseEntity response = this.cohereApi.chatCompletionEntity(new ChatCompletionRequest(
+ List.of(systemMessage, userMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, false));
+
+ assertThat(response).isNotNull();
+ assertThat(response.getBody()).isNotNull();
+ }
+
+ @Test
+ void embeddings() {
+ ResponseEntity response = this.cohereApi
+ .embeddings(CohereApi.EmbeddingRequest.builder().texts("Hello world").build());
+
+ assertThat(response).isNotNull();
+ Assertions.assertNotNull(response.getBody());
+ assertThat(response.getBody().getFloatEmbeddings()).hasSize(1);
+ assertThat(response.getBody().getFloatEmbeddings().get(0)).hasSize(1536);
+ }
+
+ @Test
+ void chatCompletionStream() {
+ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
+ Flux response = this.cohereApi.chatCompletionStream(new ChatCompletionRequest(
+ List.of(chatCompletionMessage), CohereApi.ChatModel.COMMAND_A_R7B.getValue(), 0.8, true));
+
+ assertThat(response).isNotNull();
+ assertThat(response.collectList().block()).isNotNull();
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java
new file mode 100644
index 00000000000..250f78d0bca
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/CohereApiToolFunctionCallIT.java
@@ -0,0 +1,158 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api.tool;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletion;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice;
+import org.springframework.ai.cohere.api.CohereApi.FunctionTool.Type;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.http.ResponseEntity;
+import org.springframework.util.ObjectUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Ricken Bazolo
+ */
+@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+")
+public class CohereApiToolFunctionCallIT {
+
+ static final String MISTRAL_AI_CHAT_MODEL = CohereApi.ChatModel.COMMAND_A_R7B.getValue();
+
+ private final Logger logger = LoggerFactory.getLogger(CohereApiToolFunctionCallIT.class);
+
+ MockWeatherService weatherService = new MockWeatherService();
+
+ CohereApi completionApi = CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build();
+
+ private static T fromJson(String json, Class targetClass) {
+ try {
+ return new ObjectMapper().readValue(json, targetClass);
+ }
+ catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ @SuppressWarnings("null")
+ public void toolFunctionCall() throws JsonProcessingException {
+
+ // Step 1: send the conversation and available functions to the model
+ var message = new ChatCompletionMessage(
+ "What's the weather like in San Francisco, Tokyo, and Paris? Show the temperature in Celsius.",
+ Role.USER);
+
+ var functionTool = new CohereApi.FunctionTool(Type.FUNCTION,
+ new CohereApi.FunctionTool.Function(
+ "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather",
+ ModelOptionsUtils.jsonToMap("""
+ {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state e.g. San Francisco, CA"
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["C", "F"]
+ }
+ },
+ "required": ["location", "unit"]
+ }
+ """)));
+
+ List messages = new ArrayList<>(List.of(message));
+
+ ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL,
+ List.of(functionTool), ToolChoice.REQUIRED);
+
+ System.out
+ .println(new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatCompletionRequest));
+
+ ResponseEntity response = this.completionApi.chatCompletionEntity(chatCompletionRequest);
+
+ ChatCompletion chatCompletion = response.getBody();
+
+ assertThat(chatCompletion).isNotNull();
+ assertThat(chatCompletion.message()).isNotNull();
+
+ ChatCompletionMessage responseMessage = new ChatCompletionMessage(chatCompletion.message().content(),
+ chatCompletion.message().role(), chatCompletion.message().toolPlan(),
+ chatCompletion.message().toolCalls(), chatCompletion.message().citations(), null);
+
+ assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT);
+ assertThat(responseMessage.toolCalls()).isNotNull();
+
+ // Check if the model wanted to call a function
+ if (!ObjectUtils.isEmpty(responseMessage.toolCalls())) {
+
+ // extend conversation with assistant's reply.
+ messages.add(responseMessage);
+
+ // Send the info for each function call and function response to the model.
+ for (ToolCall toolCall : responseMessage.toolCalls()) {
+ var functionName = toolCall.function().name();
+ if ("getCurrentWeather".equals(functionName)) {
+ MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(),
+ MockWeatherService.Request.class);
+
+ MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest);
+
+ // extend conversation with function response.
+ messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(),
+ Role.TOOL, functionName, null, responseMessage.citations(), toolCall.id()));
+ }
+ }
+
+ var functionResponseRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL, 0.8);
+
+ ResponseEntity result2 = this.completionApi.chatCompletionEntity(functionResponseRequest);
+
+ chatCompletion = result2.getBody();
+
+ logger.info("Final response: {}", chatCompletion);
+
+ assertThat(chatCompletion.message().content()).isNotEmpty();
+
+ var messageContent = chatCompletion.message().content().get(0);
+
+ assertThat(chatCompletion.message().role()).isEqualTo(Role.ASSISTANT);
+ assertThat(messageContent.text()).contains("San Francisco").containsAnyOf("30.0", "30");
+ assertThat(messageContent.text()).contains("Tokyo").containsAnyOf("10.0", "10");
+ assertThat(messageContent.text()).contains("Paris").containsAnyOf("15.0", "15");
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java
new file mode 100644
index 00000000000..c0b96608a7d
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/MockWeatherService.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api.tool;
+
+import java.util.function.Function;
+
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+
+public class MockWeatherService implements Function {
+
+ @Override
+ public Response apply(Request request) {
+
+ double temperature = 0;
+ if (request.location().contains("Paris")) {
+ temperature = 15;
+ }
+ else if (request.location().contains("Tokyo")) {
+ temperature = 10;
+ }
+ else if (request.location().contains("San Francisco")) {
+ temperature = 30;
+ }
+
+ return new Response(temperature, 15, 20, 2, 53, 45, Unit.C);
+ }
+
+ /**
+ * Temperature units.
+ */
+ public enum Unit {
+
+ /**
+ * Celsius.
+ */
+ C("metric"),
+ /**
+ * Fahrenheit.
+ */
+ F("imperial");
+
+ /**
+ * Human readable unit name.
+ */
+ public final String unitName;
+
+ Unit(String text) {
+ this.unitName = text;
+ }
+
+ }
+
+ /**
+ * Weather Function request.
+ */
+ @JsonInclude(Include.NON_NULL)
+ @JsonClassDescription("Weather API request")
+ public record Request(@JsonProperty(required = true,
+ value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location,
+ @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) {
+
+ }
+
+ /**
+ * Weather Function response.
+ */
+ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity,
+ Unit unit) {
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java
new file mode 100644
index 00000000000..2422154b458
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/api/tool/PaymentStatusFunctionCallingIT.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.api.tool;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletion;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.Role;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage.ToolCall;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice;
+import org.springframework.ai.cohere.api.CohereApi.FunctionTool;
+import org.springframework.ai.cohere.api.CohereApi.FunctionTool.Type;
+import org.springframework.http.ResponseEntity;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+")
+public class PaymentStatusFunctionCallingIT {
+
+ // Assuming we have the following data
+ public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002",
+ new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004",
+ new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08"));
+
+ static Map> functions = Map.of("retrieve_payment_status",
+ new RetrievePaymentStatus(), "retrieve_payment_date", new RetrievePaymentDate());
+
+ private final Logger logger = LoggerFactory.getLogger(PaymentStatusFunctionCallingIT.class);
+
+ private static T jsonToObject(String json, Class targetClass) {
+ try {
+ return new ObjectMapper().readValue(json, targetClass);
+ }
+ catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ @SuppressWarnings("null")
+ public void toolFunctionCall() throws JsonProcessingException {
+
+ var transactionJsonSchema = """
+ {
+ "type": "object",
+ "properties": {
+ "transaction_id": {
+ "type": "string",
+ "description": "The transaction id"
+ }
+ },
+ "required": ["transaction_id"]
+ }
+ """;
+
+ var paymentStatusTool = new FunctionTool(Type.FUNCTION, new FunctionTool.Function(
+ "Get payment status of a transaction", "retrieve_payment_status", transactionJsonSchema));
+
+ var paymentDateTool = new FunctionTool(Type.FUNCTION, new FunctionTool.Function(
+ "Get payment date of a transaction", "retrieve_payment_date", transactionJsonSchema));
+
+ List messages = new ArrayList<>(
+ List.of(new ChatCompletionMessage("What's the status of my transaction with id T1001?", Role.USER)));
+
+ CohereApi cohereApi = CohereApi.builder().apiKey(System.getenv("COHERE_API_KEY")).build();
+
+ ResponseEntity response = cohereApi
+ .chatCompletionEntity(new ChatCompletionRequest(messages, CohereApi.ChatModel.COMMAND_A_R7B.getValue(),
+ List.of(paymentStatusTool, paymentDateTool), ToolChoice.REQUIRED));
+
+ ChatCompletion chatCompletion = response.getBody();
+
+ ChatCompletionMessage responseMessage = new ChatCompletionMessage(chatCompletion.message().content(),
+ chatCompletion.message().role(), chatCompletion.message().toolPlan(),
+ chatCompletion.message().toolCalls(), chatCompletion.message().citations(), null);
+
+ assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT);
+ assertThat(responseMessage.toolCalls()).isNotNull();
+
+ // extend conversation with assistant's reply.
+ messages.add(responseMessage);
+
+ // Send the info for each function call and function response to the model.
+ for (ToolCall toolCall : responseMessage.toolCalls()) {
+
+ var functionName = toolCall.function().name();
+ // Map the function, JSON arguments into a Transaction object.
+ Transaction transaction = jsonToObject(toolCall.function().arguments(), Transaction.class);
+ // Call the target function with the transaction object.
+ var result = functions.get(functionName).apply(transaction);
+
+ // Extend conversation with function response.
+ // The functionName is used to identify the function response!
+ messages.add(new ChatCompletionMessage(result.toString(), Role.TOOL, functionName, null,
+ responseMessage.citations(), toolCall.id()));
+ }
+
+ response = cohereApi
+ .chatCompletionEntity(new ChatCompletionRequest(messages, CohereApi.ChatModel.COMMAND_A_R7B.getValue()));
+
+ chatCompletion = response.getBody();
+ var content = chatCompletion.message().content().get(0).text();
+ logger.info("Final response: {}", content);
+
+ assertThat(content).containsIgnoringCase("T1001");
+ assertThat(content).containsIgnoringCase("Paid");
+ }
+
+ record StatusDate(String status, String date) {
+
+ }
+
+ public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) {
+
+ }
+
+ public record Status(@JsonProperty(required = true, value = "status") String status) {
+
+ }
+
+ public record Date(@JsonProperty(required = true, value = "date") String date) {
+
+ }
+
+ private static class RetrievePaymentStatus implements Function {
+
+ @Override
+ public Status apply(Transaction paymentTransaction) {
+ return new Status(DATA.get(paymentTransaction.transactionId).status);
+ }
+
+ }
+
+ private static class RetrievePaymentDate implements Function {
+
+ @Override
+ public Date apply(Transaction paymentTransaction) {
+ return new Date(DATA.get(paymentTransaction.transactionId).date);
+ }
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java
new file mode 100644
index 00000000000..7eedb9194cf
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatClientIT.java
@@ -0,0 +1,290 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.chat;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.cohere.CohereTestConfiguration;
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionRequest.ToolChoice;
+import org.springframework.ai.cohere.api.tool.MockWeatherService;
+import org.springframework.ai.cohere.testutils.AbstractIT;
+import org.springframework.ai.converter.BeanOutputConverter;
+import org.springframework.ai.converter.ListOutputConverter;
+import org.springframework.ai.test.CurlyBracketEscaper;
+import org.springframework.ai.tool.function.FunctionToolCallback;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.core.convert.support.DefaultConversionService;
+import org.springframework.core.io.Resource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@SpringBootTest(classes = CohereTestConfiguration.class)
+@EnabledIfEnvironmentVariable(named = "COHERE_API_KEY", matches = ".+")
+class CohereChatClientIT extends AbstractIT {
+
+ private static final Logger logger = LoggerFactory.getLogger(CohereChatClientIT.class);
+
+ @Value("classpath:/prompts/system-message.st")
+ private Resource systemTextResource;
+
+ @Test
+ void call() {
+ // @formatter:off
+ ChatResponse response = ChatClient.create(this.chatModel).prompt()
+ .system(s -> s.text(this.systemTextResource)
+ .param("name", "Bob")
+ .param("voice", "pirate"))
+ .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
+ .call()
+ .chatResponse();
+ // @formatter:on
+
+ logger.info("{}", response);
+ assertThat(response.getResults()).hasSize(1);
+ assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
+ }
+
+ @Test
+ void testMessageHistory() {
+
+ // @formatter:off
+ ChatResponse response = ChatClient.create(this.chatModel).prompt()
+ .system(s -> s.text(this.systemTextResource)
+ .param("name", "Bob")
+ .param("voice", "pirate"))
+ .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
+ .call()
+ .chatResponse();
+ // @formatter:on
+ assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard");
+
+ // @formatter:off
+ response = ChatClient.create(this.chatModel).prompt()
+ .messages(List.of(new UserMessage("Dummy"), response.getResult().getOutput()))
+ .user("Repeat the last assistant message.")
+ .call()
+ .chatResponse();
+ // @formatter:on
+
+ logger.info("" + response);
+ assertThat(response.getResult().getOutput().getText().toLowerCase()).containsAnyOf("blackbeard",
+ "bartholomew roberts");
+ }
+
+ @Test
+ void listOutputConverterString() {
+ // @formatter:off
+ List collection = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("List five {subject}")
+ .param("subject", "ice cream flavors"))
+ .call()
+ .entity(new ParameterizedTypeReference<>() { });
+ // @formatter:on
+
+ logger.info(collection.toString());
+ assertThat(collection).hasSize(5);
+ }
+
+ @Test
+ void listOutputConverterBean() {
+
+ // @formatter:off
+ List actorsFilms = ChatClient.create(this.chatModel).prompt()
+ .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
+ .call()
+ .entity(new ParameterizedTypeReference<>() {
+ });
+ // @formatter:on
+
+ logger.info("" + actorsFilms);
+ assertThat(actorsFilms).hasSize(2);
+ }
+
+ @Test
+ void customOutputConverter() {
+
+ var toStringListConverter = new ListOutputConverter(new DefaultConversionService());
+
+ // @formatter:off
+ List flavors = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("List 10 {subject}")
+ .param("subject", "ice cream flavors"))
+ .call()
+ .entity(toStringListConverter);
+ // @formatter:on
+
+ logger.info("ice cream flavors{}", flavors);
+ assertThat(flavors).hasSize(10);
+ assertThat(flavors).containsAnyOf("Vanilla", "vanilla");
+ }
+
+ @Test
+ void mapOutputConverter() {
+ // @formatter:off
+ Map result = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("Provide me a List of {subject}")
+ .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'"))
+ .call()
+ .entity(new ParameterizedTypeReference<>() {
+ });
+ // @formatter:on
+
+ assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
+ }
+
+ @Test
+ void beanOutputConverter() {
+
+ // @formatter:off
+ ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt()
+ .user("Generate the filmography for a random actor.")
+ .call()
+ .entity(ActorsFilms.class);
+ // @formatter:on
+
+ logger.info("{}", actorsFilms);
+ assertThat(actorsFilms.actor()).isNotBlank();
+ }
+
+ @Test
+ void beanOutputConverterRecords() {
+
+ // @formatter:off
+ ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt()
+ .user("Generate the filmography of 5 movies for Tom Hanks.")
+ .call()
+ .entity(ActorsFilms.class);
+ // @formatter:on
+
+ logger.info("{}", actorsFilms);
+ assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
+ assertThat(actorsFilms.movies()).hasSize(5);
+ }
+
+ @Test
+ void beanStreamOutputConverterRecords() {
+
+ BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
+
+ // @formatter:off
+ Flux chatResponse = ChatClient.create(this.chatModel)
+ .prompt()
+ .advisors(new SimpleLoggerAdvisor())
+ .user(u -> u
+ .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ + "{format}")
+ .param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
+ .stream()
+ .content();
+
+ String generationTextFromStream = chatResponse.collectList()
+ .block()
+ .stream()
+ .collect(Collectors.joining());
+ // @formatter:on
+
+ ActorsFilms actorsFilms = outputConverter.convert(generationTextFromStream);
+
+ logger.info("{}", actorsFilms);
+ assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
+ assertThat(actorsFilms.movies()).hasSize(5);
+ }
+
+ @Test
+ void functionCallTest() {
+
+ // @formatter:off
+ String response = ChatClient.create(this.chatModel).prompt()
+ .options(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A_R7B).toolChoice(ToolChoice.REQUIRED).build())
+ .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
+ .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
+ .description("Get the weather in location")
+ .inputType(MockWeatherService.Request.class)
+ .build())
+ .call()
+ .content();
+ // @formatter:on
+
+ logger.info("Response: {}", response);
+
+ assertThat(response).containsAnyOf("30.0", "30");
+ assertThat(response).containsAnyOf("10.0", "10");
+ assertThat(response).containsAnyOf("15.0", "15");
+ }
+
+ @Test
+ void streamFunctionCallTest() {
+
+ // @formatter:off
+ Flux response = ChatClient.create(this.chatModel).prompt()
+ .options(CohereChatOptions.builder().model(CohereApi.ChatModel.COMMAND_A_R7B).build())
+ .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
+ .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
+ .description("Get the weather in location")
+ .inputType(MockWeatherService.Request.class)
+ .build())
+ .stream()
+ .content();
+ // @formatter:on
+
+ String content = response.collectList().block().stream().collect(Collectors.joining());
+ logger.info("Response: {}", content);
+
+ assertThat(content).containsAnyOf("30.0", "30");
+ assertThat(content).containsAnyOf("10.0", "10");
+ assertThat(content).containsAnyOf("15.0", "15");
+ }
+
+ @Test
+ void validateCallResponseMetadata() {
+ String model = CohereApi.ChatModel.COMMAND_A_R7B.getName();
+ // @formatter:off
+ ChatResponse response = ChatClient.create(this.chatModel).prompt()
+ .options(CohereChatOptions.builder().model(model).build())
+ .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
+ .call()
+ .chatResponse();
+ // @formatter:on
+
+ logger.info(response.toString());
+ assertThat(response.getMetadata().getId()).isNotEmpty();
+ assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
+ assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
+ assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
+ }
+
+ record ActorsFilms(String actor, List movies) {
+
+ }
+
+}
diff --git a/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java
new file mode 100644
index 00000000000..80dda662b01
--- /dev/null
+++ b/models/spring-ai-cohere/src/test/java/org/springframework/ai/cohere/chat/CohereChatCompletionRequestTests.java
@@ -0,0 +1,316 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.cohere.chat;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.chat.messages.AbstractMessage;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.ToolResponseMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.cohere.api.CohereApi;
+import org.springframework.ai.cohere.api.CohereApi.ChatCompletionMessage;
+import org.springframework.ai.content.Media;
+import org.springframework.ai.model.tool.ToolCallingChatOptions;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.DefaultToolDefinition;
+import org.springframework.ai.tool.definition.ToolDefinition;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * @author Ricken Bazolo
+ */
+class CohereChatCompletionRequestTests {
+
+ private static final String BASE_URL = "https://faked.url";
+
+ private static final String API_KEY = "FAKED_API_KEY";
+
+ private static final String TEXT_CONTENT = "Hello world!";
+
+ private static final String IMAGE_URL = "https://example.com/image.png";
+
+ private static final Media IMAGE_MEDIA = new Media(Media.Format.IMAGE_PNG, URI.create(IMAGE_URL));
+
+ private final CohereChatModel chatModel = CohereChatModel.builder()
+ .cohereApi(CohereApi.builder().baseUrl(BASE_URL).apiKey(API_KEY).build())
+ .build();
+
+ @Test
+ void chatCompletionDefaultRequestTest() {
+ var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content"));
+ var request = this.chatModel.createRequest(prompt, false);
+
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.temperature()).isEqualTo(0.3);
+ assertThat(request.p()).isEqualTo(1);
+ assertThat(request.maxTokens()).isNull();
+ assertThat(request.stream()).isFalse();
+ }
+
+ @Test
+ void chatCompletionRequestWithOptionsTest() {
+ var options = CohereChatOptions.builder().temperature(0.5).topP(0.8).build();
+ var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content", options));
+ var request = this.chatModel.createRequest(prompt, true);
+
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.p()).isEqualTo(0.8);
+ assertThat(request.temperature()).isEqualTo(0.5);
+ assertThat(request.stream()).isTrue();
+ }
+
+ @Test
+ void whenToolRuntimeOptionsThenMergeWithDefaults() {
+ CohereChatOptions defaultOptions = CohereChatOptions.builder()
+ .model("DEFAULT_MODEL")
+ .internalToolExecutionEnabled(true)
+ .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2"))
+ .toolNames("tool1", "tool2")
+ .toolContext(Map.of("key1", "value1", "key2", "valueA"))
+ .build();
+
+ CohereChatModel anotherChatModel = CohereChatModel.builder()
+ .cohereApi(CohereApi.builder().baseUrl(BASE_URL).apiKey(API_KEY).build())
+ .defaultOptions(defaultOptions)
+ .build();
+
+ CohereChatOptions runtimeOptions = CohereChatOptions.builder()
+ .internalToolExecutionEnabled(false)
+ .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4"))
+ .toolNames("tool3")
+ .toolContext(Map.of("key2", "valueB"))
+ .build();
+ Prompt prompt = anotherChatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions));
+
+ assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull();
+ assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse();
+ assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2);
+ assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()
+ .stream()
+ .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4");
+ assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3");
+ assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1")
+ .containsEntry("key2", "valueB");
+ }
+
+ @Test
+ void createChatCompletionMessagesWithUserMessage() {
+ var userMessage = new UserMessage(TEXT_CONTENT);
+ userMessage.getMedia().add(IMAGE_MEDIA);
+ var prompt = createPrompt(userMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ verifyUserChatCompletionMessages(chatCompletionRequest.messages());
+ }
+
+ @Test
+ void createChatCompletionMessagesWithSimpleUserMessage() {
+ var simpleUserMessage = new SimpleMessage(MessageType.USER, TEXT_CONTENT);
+ var prompt = createPrompt(simpleUserMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ var chatCompletionMessages = chatCompletionRequest.messages();
+ assertThat(chatCompletionMessages).hasSize(1);
+ var chatCompletionMessage = chatCompletionMessages.get(0);
+ assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
+ assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT);
+ }
+
+ @Test
+ void createChatCompletionMessagesWithSystemMessage() {
+ var systemMessage = new SystemMessage(TEXT_CONTENT);
+ var prompt = createPrompt(systemMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ verifySystemChatCompletionMessages(chatCompletionRequest.messages());
+ }
+
+ @Test
+ void createChatCompletionMessagesWithSimpleSystemMessage() {
+ var simpleSystemMessage = new SimpleMessage(MessageType.SYSTEM, TEXT_CONTENT);
+ var prompt = createPrompt(simpleSystemMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ verifySystemChatCompletionMessages(chatCompletionRequest.messages());
+ }
+
+ @Test
+ void createChatCompletionMessagesWithAssistantMessage() {
+ var toolCall1 = createToolCall(1);
+ var toolCall2 = createToolCall(2);
+ var toolCall3 = createToolCall(3);
+ // @formatter:off
+ var assistantMessage = AssistantMessage.builder()
+ .content(TEXT_CONTENT)
+ .toolCalls(List.of(toolCall1, toolCall2, toolCall3))
+ .build();
+ // @formatter:on
+ var prompt = createPrompt(assistantMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ var chatCompletionMessages = chatCompletionRequest.messages();
+ assertThat(chatCompletionMessages).hasSize(1);
+ var chatCompletionMessage = chatCompletionMessages.get(0);
+ assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.ASSISTANT);
+ assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT);
+ var toolCalls = chatCompletionMessage.toolCalls();
+ assertThat(toolCalls).hasSize(3);
+ verifyToolCall(toolCalls.get(0), toolCall1);
+ verifyToolCall(toolCalls.get(1), toolCall2);
+ verifyToolCall(toolCalls.get(2), toolCall3);
+ }
+
+ @Test
+ void createChatCompletionMessagesWithSimpleAssistantMessage() {
+ var simpleAssistantMessage = new SimpleMessage(MessageType.ASSISTANT, TEXT_CONTENT);
+ var prompt = createPrompt(simpleAssistantMessage);
+ assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unsupported assistant message class: " + SimpleMessage.class.getName());
+ }
+
+ @Test
+ void createChatCompletionMessagesWithToolResponseMessage() {
+ var toolResponse1 = createToolResponse(1);
+ var toolResponse2 = createToolResponse(2);
+ var toolResponse3 = createToolResponse(3);
+ var toolResponseMessage = ToolResponseMessage.builder()
+ .responses(List.of(toolResponse1, toolResponse2, toolResponse3))
+ .build();
+ var prompt = createPrompt(toolResponseMessage);
+ var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
+ var chatCompletionMessages = chatCompletionRequest.messages();
+ assertThat(chatCompletionMessages).hasSize(3);
+ verifyToolChatCompletionMessage(chatCompletionMessages.get(0), toolResponse1);
+ verifyToolChatCompletionMessage(chatCompletionMessages.get(1), toolResponse2);
+ verifyToolChatCompletionMessage(chatCompletionMessages.get(2), toolResponse3);
+ }
+
+ @Test
+ void createChatCompletionMessagesWithInvalidToolResponseMessage() {
+ var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null);
+ var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build();
+ var prompt = createPrompt(toolResponseMessage);
+ assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("ToolResponseMessage.ToolResponse must have an id");
+ }
+
+ @Test
+ void createChatCompletionMessagesWithSimpleToolMessage() {
+ var simpleToolMessage = new SimpleMessage(MessageType.TOOL, TEXT_CONTENT);
+ var prompt = createPrompt(simpleToolMessage);
+ assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unsupported tool message class: " + SimpleMessage.class.getName());
+ }
+
+ private Prompt createPrompt(Message message) {
+ var chatOptions = CohereChatOptions.builder().temperature(0.7d).build();
+ var prompt = new Prompt(message, chatOptions);
+
+ return this.chatModel.buildRequestPrompt(prompt);
+ }
+
+ private static void verifyToolChatCompletionMessage(ChatCompletionMessage chatCompletionMessage,
+ ToolResponseMessage.ToolResponse toolResponse) {
+ assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.TOOL);
+ assertThat(chatCompletionMessage.content()).isEqualTo(toolResponse.responseData());
+ assertThat(chatCompletionMessage.toolCalls()).isNull();
+ assertThat(chatCompletionMessage.toolCallId()).isEqualTo(toolResponse.id());
+ }
+
+ private static ToolResponseMessage.ToolResponse createToolResponse(int number) {
+ return new ToolResponseMessage.ToolResponse("id" + number, "name" + number, "responseData" + number);
+ }
+
+ private static void verifyToolCall(ChatCompletionMessage.ToolCall mistralToolCall,
+ AssistantMessage.ToolCall toolCall) {
+ assertThat(mistralToolCall.id()).isEqualTo(toolCall.id());
+ assertThat(mistralToolCall.type()).isEqualTo(toolCall.type());
+ var function = mistralToolCall.function();
+ assertThat(function).isNotNull();
+ assertThat(function.name()).isEqualTo(toolCall.name());
+ assertThat(function.arguments()).isEqualTo(toolCall.arguments());
+ }
+
+ private static AssistantMessage.ToolCall createToolCall(int number) {
+ return new AssistantMessage.ToolCall("id" + number, "type" + number, "name" + number, "arguments " + number);
+ }
+
+ private static void verifySystemChatCompletionMessages(List chatCompletionMessages) {
+ assertThat(chatCompletionMessages).hasSize(1);
+ var chatCompletionMessage = chatCompletionMessages.get(0);
+ assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.SYSTEM);
+ assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT);
+ }
+
+ private static void verifyUserChatCompletionMessages(List chatCompletionMessages) {
+ assertThat(chatCompletionMessages).hasSize(1);
+ var chatCompletionMessage = chatCompletionMessages.get(0);
+ assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
+ var rawContent = chatCompletionMessage.rawContent();
+ assertThat(rawContent).isNotNull();
+ var maps = (List