diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java index e3658a226a7..c88bb675af2 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java @@ -22,10 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger; import com.oracle.bmc.generativeaiinference.GenerativeAiInference; -import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; import com.oracle.bmc.generativeaiinference.model.EmbedTextResult; -import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; import com.oracle.bmc.generativeaiinference.model.ServingMode; import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest; import io.micrometer.observation.ObservationRegistry; @@ -128,15 +126,6 @@ private EmbeddingResponse embedAllWithContext(List embedTextRe return embeddingResponse; } - private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) { - return switch (embeddingOptions.getServingMode()) { - case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build(); - case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build(); - default -> throw new IllegalArgumentException( - "unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode()); - }; - } - private List createRequests(List inputs, OCIEmbeddingOptions embeddingOptions) { int size = inputs.size(); List requests = new ArrayList<>(); @@ -148,8 +137,9 @@ private List createRequests(List inputs, OCIEmbeddingO } private EmbedTextRequest createRequest(List inputs, OCIEmbeddingOptions embeddingOptions) { + ServingMode servingMode = 
ServingModeHelper.get(embeddingOptions.getServingMode(), embeddingOptions.getModel()); EmbedTextDetails embedTextDetails = EmbedTextDetails.builder() - .servingMode(servingMode(embeddingOptions)) + .servingMode(servingMode) .compartmentId(embeddingOptions.getCompartment()) .inputs(inputs) .truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End)) diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java new file mode 100644 index 00000000000..e621c69ac6a --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.oci; + +import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; +import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; +import com.oracle.bmc.generativeaiinference.model.ServingMode; + +/** + * Helper class that resolves a serving mode name and a model identifier into the OCI Gen AI + * {@link com.oracle.bmc.generativeaiinference.model.ServingMode}; a {@code null} or unknown name is rejected with an {@link IllegalArgumentException}. + * + * @author Anders Swanson + */ +public final class ServingModeHelper { + + public static ServingMode get(String servingMode, String model) { + return switch (servingMode == null ? "" : servingMode) { + case "dedicated" -> DedicatedServingMode.builder().endpointId(model).build(); + case "on-demand" -> OnDemandServingMode.builder().modelId(model).build(); + default -> throw new IllegalArgumentException(String.format( + "Unknown serving mode for OCI Gen AI: %s. Supported options are 'dedicated' and 'on-demand'", + servingMode)); + }; + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java new file mode 100644 index 00000000000..4f8267c1646 --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -0,0 +1,248 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.oci.cohere; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import com.oracle.bmc.generativeaiinference.GenerativeAiInference; +import com.oracle.bmc.generativeaiinference.model.BaseChatRequest; +import com.oracle.bmc.generativeaiinference.model.BaseChatResponse; +import com.oracle.bmc.generativeaiinference.model.ChatDetails; +import com.oracle.bmc.generativeaiinference.model.CohereChatBotMessage; +import com.oracle.bmc.generativeaiinference.model.CohereChatRequest; +import com.oracle.bmc.generativeaiinference.model.CohereChatResponse; +import com.oracle.bmc.generativeaiinference.model.CohereMessage; +import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage; +import com.oracle.bmc.generativeaiinference.model.CohereToolCall; +import com.oracle.bmc.generativeaiinference.model.CohereToolMessage; +import com.oracle.bmc.generativeaiinference.model.CohereToolResult; +import com.oracle.bmc.generativeaiinference.model.CohereUserMessage; +import com.oracle.bmc.generativeaiinference.model.ServingMode; +import com.oracle.bmc.generativeaiinference.requests.ChatRequest; +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import 
org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.oci.ServingModeHelper; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import static java.util.Objects.requireNonNullElse; + +/** + * {@link ChatModel} implementation that uses the OCI GenAI Chat API. + * + * @author Anders Swanson + * @since 1.0.0 + */ +public class OCICohereChatModel implements ChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + /** + * The {@link GenerativeAiInference} client used to interact with OCI GenAI service. + */ + private final GenerativeAiInference genAi; + + /** + * The configuration information for a chat completions request. + */ + private final OCICohereChatOptions defaultOptions; + + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. 
+ */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options) { + this(genAi, options, null); + } + + public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options, + ObservationRegistry observationRegistry) { + Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInference must not be null"); + Assert.notNull(options, "OCIChatOptions must not be null"); + + this.genAi = genAi; + this.defaultOptions = options; + this.observationRegistry = observationRegistry; + } + + @Override + public ChatResponse call(Prompt prompt) { + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(AiProvider.OCI_GENAI.value()) + .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) + .build(); + + return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ChatResponse chatResponse = doChatRequest(prompt); + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + + @Override + public ChatOptions getDefaultOptions() { + return OCICohereChatOptions.fromOptions(defaultOptions); + } + + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + private ChatResponse doChatRequest(Prompt prompt) { + OCICohereChatOptions options = mergeOptions(prompt.getOptions(), this.defaultOptions); + validateChatOptions(options); + + ChatResponseMetadata metadata = 
ChatResponseMetadata.builder() + .withModel(options.getModel()) + .withKeyValue("compartment", options.getCompartment()) + .build(); + return new ChatResponse(getGenerations(prompt, options), metadata); + + } + + private OCICohereChatOptions mergeOptions(ChatOptions chatOptions, OCICohereChatOptions defaultOptions) { + if (chatOptions instanceof OCICohereChatOptions override) { + OCICohereChatOptions dynamicOptions = ModelOptionsUtils.merge(override, defaultOptions, + OCICohereChatOptions.class); + + if (dynamicOptions != null) { + return dynamicOptions; + } + } + return defaultOptions; + } + + private void validateChatOptions(OCICohereChatOptions options) { + if (!StringUtils.hasText(options.getModel())) { + throw new IllegalArgumentException("Model is not set!"); + } + if (!StringUtils.hasText(options.getCompartment())) { + throw new IllegalArgumentException("Compartment is not set!"); + } + if (!StringUtils.hasText(options.getServingMode())) { + throw new IllegalArgumentException("ServingMode is not set!"); + } + } + + private List getGenerations(Prompt prompt, OCICohereChatOptions options) { + com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = genAi + .chat(toCohereChatRequest(prompt, options)); + return toGenerations(cr, options); + + } + + private List toGenerations(com.oracle.bmc.generativeaiinference.responses.ChatResponse ociChatResponse, + OCICohereChatOptions options) { + BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse(); + if (cr instanceof CohereChatResponse resp) { + List generations = new ArrayList<>(); + ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null); + AssistantMessage message = new AssistantMessage(resp.getText(), Map.of()); + generations.add(new Generation(message, metadata)); + return generations; + } + throw new IllegalStateException(String.format("Unexpected chat response type: %s", cr.getClass().getName())); + } + + private ChatRequest 
toCohereChatRequest(Prompt prompt, OCICohereChatOptions options) { + List messages = prompt.getInstructions(); + Message message = messages.get(0); + List chatHistory = getCohereMessages(messages); + return newChatRequest(options, message, chatHistory); + } + + private List getCohereMessages(List messages) { + List chatHistory = new ArrayList<>(); + for (int i = 1; i < messages.size(); i++) { + Message message = messages.get(i); + switch (message.getMessageType()) { + case USER -> chatHistory.add(CohereUserMessage.builder().message(message.getContent()).build()); + case ASSISTANT -> chatHistory.add(CohereChatBotMessage.builder().message(message.getContent()).build()); + case SYSTEM -> chatHistory.add(CohereSystemMessage.builder().message(message.getContent()).build()); + case TOOL -> { + if (message instanceof ToolResponseMessage tm) { + chatHistory.add(toToolMessage(tm)); + } + } + } + } + return chatHistory; + } + + private CohereToolMessage toToolMessage(ToolResponseMessage tm) { + List results = tm.getResponses().stream().map(r -> { + CohereToolCall call = CohereToolCall.builder().name(r.name()).build(); + return CohereToolResult.builder().call(call).outputs(List.of(r.responseData())).build(); + }).toList(); + return CohereToolMessage.builder().toolResults(results).build(); + } + + private ChatRequest newChatRequest(OCICohereChatOptions options, Message message, List chatHistory) { + BaseChatRequest baseChatRequest = CohereChatRequest.builder() + .frequencyPenalty(options.getFrequencyPenalty()) + .presencePenalty(options.getPresencePenalty()) + .maxTokens(options.getMaxTokens()) + .topK(options.getTopK()) + .topP(options.getTopP()) + .temperature(requireNonNullElse(options.getTemperature(), DEFAULT_TEMPERATURE)) + .preambleOverride(options.getPreambleOverride()) + .stopSequences(options.getStopSequences()) + .documents(options.getDocuments()) + .tools(options.getTools()) + .chatHistory(chatHistory) + .message(message.getContent()) + .build(); + ServingMode 
servingMode = ServingModeHelper.get(options.getServingMode(), options.getModel()); + ChatDetails chatDetails = ChatDetails.builder() + .compartmentId(options.getCompartment()) + .servingMode(servingMode) + .chatRequest(baseChatRequest) + .build(); + return ChatRequest.builder().body$(chatDetails).build(); + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java new file mode 100644 index 00000000000..77e5a32ee45 --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java @@ -0,0 +1,348 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.oci.cohere; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.oracle.bmc.generativeaiinference.model.CohereTool; +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * The configuration information for OCI chat requests + * + * @author Anders Swanson + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class OCICohereChatOptions implements ChatOptions { + + @JsonProperty("model") + private String model; + + /** + * The maximum number of tokens to generate per request. 
+ */ + @JsonProperty("maxTokens") + private Integer maxTokens; + + /** + * The OCI Compartment to run chat requests in. + */ + @JsonProperty("compartment") + private String compartment; + + /** + * The serving mode of OCI Gen AI model used. May be "on-demand" or "dedicated". + */ + @JsonProperty("servingMode") + private String servingMode; + + /** + * The optional override to the chat model's prompt preamble. + */ + @JsonProperty("preambleOverride") + private String preambleOverride; + + /** + * The sample temperature, where higher values are more random, and lower values are + * more deterministic. + */ + @JsonProperty(value = "temperature") + private Double temperature; + + /** + * The Top P parameter modifies the probability of tokens sampled. E.g., a value of + * 0.25 means only tokens from the top 25% probability mass will be considered. + */ + @JsonProperty("topP") + private Double topP; + + /** + * The Top K parameter limits the number of potential tokens considered at each step + * of text generation. E.g., a value of 5 means only the top 5 most probable tokens + * will be considered during each step of text generation. + */ + @JsonProperty("topK") + private Integer topK; + + /** + * The frequency penalty assigns a penalty to repeated tokens depending on how many + * times it has already appeared in the prompt or output. Higher values will reduce + * repeated tokens and outputs will be more random. + */ + @JsonProperty("frequencyPenalty") + private Double frequencyPenalty; + + /** + * The presence penalty assigns a penalty to each token when it appears in the output + * to encourage generating outputs with tokens that haven't been used. + */ + @JsonProperty("presencePenalty") + private Double presencePenalty; + + /** + * A collection of textual sequences that will end completions generation. + */ + @JsonProperty("stop") + private List stop; + + /** + * Documents for chat context. 
+ */ + @JsonProperty("documents") + private List documents; + + /** + * Tools for the chat bot. + */ + @JsonProperty("tools") + private List tools; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected OCICohereChatOptions chatOptions; + + public Builder() { + this.chatOptions = new OCICohereChatOptions(); + } + + public Builder(OCICohereChatOptions chatOptions) { + this.chatOptions = chatOptions; + } + + public Builder withModel(String model) { + this.chatOptions.model = model; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.chatOptions.maxTokens = maxTokens; + return this; + } + + public Builder withCompartment(String compartment) { + this.chatOptions.compartment = compartment; + return this; + } + + public Builder withServingMode(String servingMode) { + this.chatOptions.servingMode = servingMode; + return this; + } + + public Builder withPreambleOverride(String preambleOverride) { + this.chatOptions.preambleOverride = preambleOverride; + return this; + } + + public Builder withTemperature(Double temperature) { + this.chatOptions.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.chatOptions.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.chatOptions.topK = topK; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.chatOptions.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.chatOptions.presencePenalty = presencePenalty; + return this; + } + + public Builder withStop(List stop) { + this.chatOptions.stop = stop; + return this; + } + + public Builder withDocuments(List documents) { + this.chatOptions.documents = documents; + return this; + } + + public Builder withTools(List tools) { + this.chatOptions.tools = tools; + return this; + } + + public OCICohereChatOptions build() { + 
return this.chatOptions; + } + + } + + public static OCICohereChatOptions fromOptions(OCICohereChatOptions fromOptions) { + return builder().withModel(fromOptions.model) + .withMaxTokens(fromOptions.maxTokens) + .withCompartment(fromOptions.compartment) + .withServingMode(fromOptions.servingMode) + .withPreambleOverride(fromOptions.preambleOverride) + .withTemperature(fromOptions.temperature) + .withTopP(fromOptions.topP) + .withTopK(fromOptions.topK) + .withStop(fromOptions.stop) + .withFrequencyPenalty(fromOptions.frequencyPenalty) + .withPresencePenalty(fromOptions.presencePenalty) + .withDocuments(fromOptions.documents) + .withTools(fromOptions.tools) + .build(); + } + + /* + * Getters and setters + */ + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public String getPreambleOverride() { + return preambleOverride; + } + + public void setPreambleOverride(String preambleOverride) { + this.preambleOverride = preambleOverride; + } + + public String getServingMode() { + return servingMode; + } + + public void setServingMode(String servingMode) { + this.servingMode = servingMode; + } + + public String getCompartment() { + return compartment; + } + + public void setCompartment(String compartment) { + this.compartment = compartment; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public void setModel(String model) { + this.model = model; + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public List getDocuments() { + return documents; + } + + public void setDocuments(List 
documents) { + this.documents = documents; + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + /* + * ChatOptions overrides. + */ + + @Override + public String getModel() { + return model; + } + + @Override + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return maxTokens; + } + + @Override + public Double getPresencePenalty() { + return presencePenalty; + } + + @Override + public List getStopSequences() { + return stop; + } + + @Override + public Double getTemperature() { + return temperature; + } + + @Override + public Integer getTopK() { + return topK; + } + + @Override + public Double getTopP() { + return topP; + } + + @Override + public ChatOptions copy() { + return fromOptions(this); + } + +} diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java index b1f6da89b35..77b71896b49 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java @@ -16,50 +16,23 @@ package org.springframework.ai.oci; -import java.io.IOException; -import java.nio.file.Paths; - -import com.oracle.bmc.Region; -import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider; -import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; - -public class BaseEmbeddingModelTest { - - public static final String OCI_COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; +public class BaseEmbeddingModelTest extends BaseOCIGenAITest { public static final String EMBEDDING_MODEL_V2 = "cohere.embed-english-light-v2.0"; public static final String EMBEDDING_MODEL_V3 = "cohere.embed-english-v3.0"; - private static final String CONFIG_FILE = 
Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); - - private static final String PROFILE = "DEFAULT"; - - private static final String REGION = "us-chicago-1"; - - private static final String COMPARTMENT_ID = System.getenv(OCI_COMPARTMENT_ID_KEY); - /** * Create an OCIEmbeddingModel instance using a config file authentication provider. * @return OCIEmbeddingModel instance */ - public static OCIEmbeddingModel get() { - try { - ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( - CONFIG_FILE, PROFILE); - GenerativeAiInferenceClient aiClient = GenerativeAiInferenceClient.builder() - .region(Region.valueOf(REGION)) - .build(authProvider); - OCIEmbeddingOptions options = OCIEmbeddingOptions.builder() - .withModel(EMBEDDING_MODEL_V2) - .withCompartment(COMPARTMENT_ID) - .withServingMode("on-demand") - .build(); - return new OCIEmbeddingModel(aiClient, options); - } - catch (IOException e) { - throw new RuntimeException(e); - } + public static OCIEmbeddingModel getEmbeddingModel() { + OCIEmbeddingOptions options = OCIEmbeddingOptions.builder() + .withModel(EMBEDDING_MODEL_V2) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand") + .build(); + return new OCIEmbeddingModel(getGenerativeAIClient(), options); } } diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java new file mode 100644 index 00000000000..358563cb063 --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.oci; + +import java.io.IOException; +import java.nio.file.Paths; + +import com.oracle.bmc.Region; +import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider; +import com.oracle.bmc.generativeaiinference.GenerativeAiInference; +import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; +import org.springframework.ai.oci.cohere.OCICohereChatOptions; + +public class BaseOCIGenAITest { + + public static final String OCI_COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; + + public static final String OCI_CHAT_MODEL_ID_KEY = "OCI_CHAT_MODEL_ID"; + + public static final String CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); + + public static final String PROFILE = "DEFAULT"; + + public static final String REGION = "us-chicago-1"; + + public static final String COMPARTMENT_ID = System.getenv(OCI_COMPARTMENT_ID_KEY); + + public static final String CHAT_MODEL_ID = System.getenv(OCI_CHAT_MODEL_ID_KEY); + + public static GenerativeAiInference getGenerativeAIClient() { + try { + ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( + CONFIG_FILE, PROFILE); + return GenerativeAiInferenceClient.builder().region(Region.valueOf(REGION)).build(authProvider); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static OCICohereChatOptions.Builder options() { + return OCICohereChatOptions.builder() + .withModel(CHAT_MODEL_ID) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand"); + } + +} 
diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java index 586fbfddeeb..d136c4999f2 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java @@ -31,7 +31,7 @@ matches = ".+") public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { - private final OCIEmbeddingModel embeddingModel = get(); + private final OCIEmbeddingModel embeddingModel = getEmbeddingModel(); private final List content = List.of("How many states are in the USA?", "How many states are in India?"); diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java new file mode 100644 index 00000000000..f12684caab8 --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.oci.cohere; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.oci.BaseOCIGenAITest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_CHAT_MODEL_ID_KEY; +import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_COMPARTMENT_ID_KEY; + +@EnabledIfEnvironmentVariable(named = OCI_COMPARTMENT_ID_KEY, matches = ".+") +@EnabledIfEnvironmentVariable(named = OCI_CHAT_MODEL_ID_KEY, matches = ".+") +public class OCICohereChatModelIT extends BaseOCIGenAITest { + + private static final ChatModel chatModel = new OCICohereChatModel(getGenerativeAIClient(), options().build()); + + @Test + void chatSimple() { + String response = chatModel.call("Tell me a random fact about Canada"); + assertThat(response).isNotBlank(); + } + + @Test + void chatMessages() { + String response = chatModel.call(new UserMessage("Tell me a random fact about the Arctic Circle"), + new SystemMessage("You are a helpful assistant")); + assertThat(response).isNotBlank(); + } + + @Test + void chatPrompt() { + ChatResponse response = chatModel.call(new Prompt("What's the difference between Top P and Top K sampling?")); + assertThat(response).isNotNull(); + assertThat(response.getMetadata().getModel()).isEqualTo(CHAT_MODEL_ID); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput().getContent()).isNotBlank(); + } + +} diff --git a/pom.xml b/pom.xml index 9ea06c8a067..d0ef0d8d5be 100644 --- a/pom.xml +++ b/pom.xml @@ -180,7 +180,7 @@ 0.30.0 1.19.2 - 3.46.1 + 3.51.0 26.48.0 1.9.1 2.0.9 diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 5427dd487dc..eeeae3a1084 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -31,7 +31,8 @@ //// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling] *** xref:api/chat/nvidia-chat.adoc[NVIDIA] *** xref:api/chat/ollama-chat.adoc[Ollama] -**** xref:api/chat/functions/ollama-chat-functions.adoc[Function Calling] +*** OCI Generative AI +**** xref:api/chat/oci-genai/cohere-chat.adoc[Cohere] *** xref:api/chat/openai-chat.adoc[OpenAI] **** xref:api/chat/functions/openai-chat-functions.adoc[Function Calling] *** xref:api/chat/qianfan-chat.adoc[QianFan] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc index df9cb03d53d..3c989029215 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc @@ -19,25 +19,26 @@ This table compares various Chat Models supported by Spring AI, detailing their |==== | Provider | Multimodality ^| Tools/Functions ^| Streaming ^| Retry ^| Observability ^| Built-in JSON ^| Local ^| OpenAI API Compatible -| xref::api/chat/anthropic-chat.adoc[Anthropic Claude] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/azure-openai-chat.adoc[Azure OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/vertexai-gemini-chat.adoc[Google VertexAI Gemini] | text, image, 
audio, video ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/vertexai-palm2-chat.adoc[Google VertexAI PaML2] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/groq-chat.adoc[Groq (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/huggingface.adoc[HuggingFace] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/mistralai-chat.adoc[Mistral AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/minimax-chat.adoc[MiniMax] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| +| xref::api/chat/anthropic-chat.adoc[Anthropic Claude] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/azure-openai-chat.adoc[Azure OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| 
image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/vertexai-gemini-chat.adoc[Google VertexAI Gemini] | text, image, audio, video ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/vertexai-palm2-chat.adoc[Google VertexAI PaML2] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/groq-chat.adoc[Groq (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/huggingface.adoc[HuggingFace] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/mistralai-chat.adoc[Mistral AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/minimax-chat.adoc[MiniMax] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| | xref::api/chat/moonshot-chat.adoc[Moonshot AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| -| xref::api/chat/nvidia-chat.adoc[NVIDIA (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| 
image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/ollama-chat.adoc[Ollama] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] -| xref::api/chat/openai-chat.adoc[OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/qianfan-chat.adoc[QianFan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/zhipuai-chat.adoc[ZhiPu AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/watsonx-ai-chat.adoc[Watsonx.AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-cohere.adoc[Amazon Bedrock/Cohere] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-jurassic2.adoc[Amazon Bedrock/Jurassic] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| 
xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/nvidia-chat.adoc[NVIDIA (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/oci-genai/cohere-chat.adoc[OCI GenAI/Cohere] | text ^a| image::no.svg[width=16] ^a| image::no.svg[width=16] ^a| image::no.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=16] ^a| image::no.svg[width=16] ^a| image::no.svg[width=16] +| xref::api/chat/ollama-chat.adoc[Ollama] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] +| xref::api/chat/openai-chat.adoc[OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/qianfan-chat.adoc[QianFan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| 
image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/zhipuai-chat.adoc[ZhiPu AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/watsonx-ai-chat.adoc[Watsonx.AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-cohere.adoc[Amazon Bedrock/Cohere] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-jurassic2.adoc[Amazon Bedrock/Jurassic] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| 
image::no.svg[width=12] |==== diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc new file mode 100644 index 00000000000..5cfd220c42a --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc @@ -0,0 +1,208 @@ += OCI GenAI Cohere Chat + +https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/[OCI GenAI Service] offers generative AI chat with on-demand models, or dedicated AI clusters. + +The https://docs.oracle.com/en-us/iaas/Content/generative-ai/chat-models.htm[OCI Chat Models Page] and https://docs.oracle.com/en-us/iaas/Content/generative-ai/use-playground-embed.htm[OCI Generative AI Playground] provide detailed information about using and hosting chat models on OCI. + +== Prerequisites + +You will need an active https://signup.oraclecloud.com/[Oracle Cloud Infrastructure (OCI)] account to use the OCI GenAI Cohere Chat client. The client offers four different ways to connect, including simple authentication with a user and private key, workload identity, instance principal, or OCI configuration file authentication. + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the OCI GenAI Cohere Chat Client. 
+To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-oci-genai-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-oci-genai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Connection Properties + +The prefix `spring.ai.oci.genai` is the property prefix to configure the connection to OCI GenAI. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.oci.genai.authenticationType | The type of authentication to use when authenticating to OCI. May be `file`, `instance-principal`, `workload-identity`, or `simple`. | file +| spring.ai.oci.genai.region | OCI service region. | us-chicago-1 +| spring.ai.oci.genai.tenantId | OCI tenant OCID, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.userId | OCI user OCID, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.fingerprint | Private key fingerprint, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.privateKey | Private key content, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.passPhrase | Optional private key passphrase, used when authenticating with `simple` auth and a passphrase protected private key. | - +| spring.ai.oci.genai.file | Path to OCI config file. Used when authenticating with `file` auth. | <user's home directory>/.oci/config +| spring.ai.oci.genai.profile | OCI profile name. Used when authenticating with `file` auth. | DEFAULT +| spring.ai.oci.genai.endpoint | Optional OCI GenAI endpoint. 
| - + +|==== + + +==== Configuration Properties + +The prefix `spring.ai.oci.genai.cohere.chat` is the property prefix that configures the `ChatModel` implementation for OCI GenAI Cohere Chat. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.oci.genai.cohere.chat.enabled | Enable OCI GenAI Cohere chat model. | true +| spring.ai.oci.genai.cohere.chat.options.model | Model OCID or endpoint | - +| spring.ai.oci.genai.cohere.chat.options.compartment | Model compartment OCID. | - +| spring.ai.oci.genai.cohere.chat.options.servingMode | The model serving mode to be used. May be `on-demand`, or `dedicated`. | on-demand +| spring.ai.oci.genai.cohere.chat.options.preambleOverride | Override the chat model's prompt preamble | - +| spring.ai.oci.genai.cohere.chat.options.temperature | Inference temperature | 0.7 +| spring.ai.oci.genai.cohere.chat.options.topP | Top P parameter | - +| spring.ai.oci.genai.cohere.chat.options.topK | Top K parameter | - +| spring.ai.oci.genai.cohere.chat.options.frequencyPenalty | Higher values will reduce repeated tokens and outputs will be more random. | - +| spring.ai.oci.genai.cohere.chat.options.presencePenalty | Higher values encourage generating outputs with tokens that haven't been used. | - +| spring.ai.oci.genai.cohere.chat.options.stop | List of textual sequences that will end completions generation. | - +| spring.ai.oci.genai.cohere.chat.options.documents | List of documents used in chat context. | - +|==== + +TIP: All properties prefixed with `spring.ai.oci.genai.cohere.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java[OCICohereChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. 
+ +On start-up, the default options can be configured with the `OCICohereChatModel(api, options)` constructor or the `spring.ai.oci.genai.cohere.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example, to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + OCICohereChatOptions.builder() + .withModel("my-model-ocid") + .withCompartment("my-compartment-ocid") + .withTemperature(0.5) + .build() + )); +---- + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-oci-genai-spring-boot-starter` to your pom (or gradle) dependencies. + +Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the OCI GenAI Cohere chat model: + +[source,application.properties] +---- +spring.ai.oci.genai.authenticationType=file +spring.ai.oci.genai.file=/path/to/oci/config/file +spring.ai.oci.genai.cohere.chat.options.compartment=my-compartment-ocid +spring.ai.oci.genai.cohere.chat.options.servingMode=on-demand +spring.ai.oci.genai.cohere.chat.options.model=my-chat-model-ocid +---- + +TIP: Replace the `file`, `compartment`, and `model` with your values from your OCI account. + +This will create an `OCICohereChatModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the chat model for text generations. 
+ +[source,java] +---- +@RestController +public class ChatController { + + private final OCICohereChatModel chatModel; + + @Autowired + public ChatController(OCICohereChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map<String, String> generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return chatModel.stream(prompt); + } +} +---- + +== Manual Configuration +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java[OCICohereChatModel] implements the `ChatModel` and uses the OCI Java SDK to connect to the OCI GenAI service. + +Add the `spring-ai-oci-genai` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-oci-genai</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-oci-genai' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. 
+ +Next, create an `OCICohereChatModel` and use it for text generations: + +[source,java] +---- +var CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); +var COMPARTMENT_ID = System.getenv("OCI_COMPARTMENT_ID"); +var MODEL_ID = System.getenv("OCI_CHAT_MODEL_ID"); + +ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( + CONFIG_FILE, + "DEFAULT" +); +var genAi = GenerativeAiInferenceClient.builder() + .region(Region.valueOf("us-chicago-1")) + .build(authProvider); + +var chatModel = new OCICohereChatModel(genAi, OCICohereChatOptions.builder() + .withModel(MODEL_ID) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand") + .build()); + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +The `OCICohereChatOptions` provides the configuration information for the chat requests. +The `OCICohereChatOptions.Builder` is a fluent options builder. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java new file mode 100644 index 00000000000..cdd4fcb711b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.oci.genai; + +import org.springframework.ai.oci.cohere.OCICohereChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Anders Swanson + */ +@ConfigurationProperties(OCICohereChatModelProperties.CONFIG_PREFIX) +public class OCICohereChatModelProperties { + + public static final String CONFIG_PREFIX = "spring.ai.oci.genai.cohere.chat"; + + private static final String DEFAULT_SERVING_MODE = ServingMode.ON_DEMAND.getMode(); + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + private boolean enabled; + + @NestedConfigurationProperty + private OCICohereChatOptions options = OCICohereChatOptions.builder() + .withServingMode(DEFAULT_SERVING_MODE) + .withTemperature(DEFAULT_TEMPERATURE) + .build(); + + public boolean isEnabled() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public OCICohereChatOptions getOptions() { + return options; + } + + public void setOptions(OCICohereChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java index 681ee71f3c3..fb1cfbfc8b0 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java @@ -28,8 +28,11 @@ import com.oracle.bmc.auth.okeworkloadidentity.OkeWorkloadIdentityAuthenticationDetailsProvider; import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; import com.oracle.bmc.retrier.RetryConfiguration; - +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.oci.OCIEmbeddingModel; +import org.springframework.ai.oci.cohere.OCICohereChatModel; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -43,7 +46,8 @@ */ @AutoConfiguration @ConditionalOnClass({ GenerativeAiInferenceClient.class, OCIEmbeddingModel.class }) -@EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class }) +@EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class, + OCICohereChatModelProperties.class, }) public class OCIGenAiAutoConfiguration { private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) @@ -89,4 +93,17 @@ public OCIEmbeddingModel ociEmbeddingModel(GenerativeAiInferenceClient generativ return new OCIEmbeddingModel(generativeAiClient, properties.getEmbeddingOptions()); } + @Bean + @ConditionalOnProperty(prefix = OCICohereChatModelProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public OCICohereChatModel ociChatModel(GenerativeAiInferenceClient generativeAiClient, + 
OCICohereChatModelProperties properties, ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + var chatModel = new OCICohereChatModel(generativeAiClient, properties.getOptions(), + observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + observationConvention.ifAvailable(chatModel::setObservationConvention); + + return chatModel; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java new file mode 100644 index 00000000000..51badac0a82 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.oci.genai; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; + +import com.oracle.bmc.http.client.pki.Pem; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.springframework.ai.oci.cohere.OCICohereChatModel; +import org.springframework.ai.oci.cohere.OCICohereChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +public class OCIGenAIAutoConfigurationTest { + + @Test + void setProperties(@TempDir Path tempDir) throws Exception { + Path tmp = tempDir.resolve("my-key.pem"); + createPrivateKey(tmp); + ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.oci.genai.authenticationType=simple", + "spring.ai.oci.genai.userId=my-user", + "spring.ai.oci.genai.tenantId=my-tenant", + "spring.ai.oci.genai.fingerprint=xyz", + "spring.ai.oci.genai.privateKey=" + tmp.toAbsolutePath(), + "spring.ai.oci.genai.region=us-ashburn-1", + "spring.ai.oci.genai.cohere.chat.options.compartment=my-compartment", + "spring.ai.oci.genai.cohere.chat.options.servingMode=dedicated", + "spring.ai.oci.genai.cohere.chat.options.model=my-model", + "spring.ai.oci.genai.cohere.chat.options.maxTokens=1000", + "spring.ai.oci.genai.cohere.chat.options.temperature=0.5", + "spring.ai.oci.genai.cohere.chat.options.topP=0.8", + "spring.ai.oci.genai.cohere.chat.options.maxTokens=1000", + "spring.ai.oci.genai.cohere.chat.options.frequencyPenalty=0.1", + "spring.ai.oci.genai.cohere.chat.options.presencePenalty=0.2" + // @formatter:on + ).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + + contextRunner.run(context -> { + OCICohereChatModel chatModel = 
context.getBean(OCICohereChatModel.class); + assertThat(chatModel).isNotNull(); + OCICohereChatOptions options = (OCICohereChatOptions) chatModel.getDefaultOptions(); + assertThat(options.getCompartment()).isEqualTo("my-compartment"); + assertThat(options.getModel()).isEqualTo("my-model"); + assertThat(options.getServingMode()).isEqualTo("dedicated"); + assertThat(options.getMaxTokens()).isEqualTo(1000); + assertThat(options.getTemperature()).isEqualTo(0.5); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.1); + assertThat(options.getPresencePenalty()).isEqualTo(0.2); + + OCIConnectionProperties props = context.getBean(OCIConnectionProperties.class); + assertThat(props.getAuthenticationType()).isEqualTo(OCIConnectionProperties.AuthenticationType.SIMPLE); + assertThat(props.getUserId()).isEqualTo("my-user"); + assertThat(props.getTenantId()).isEqualTo("my-tenant"); + assertThat(props.getFingerprint()).isEqualTo("xyz"); + assertThat(props.getPrivateKey()).isEqualTo(tmp.toAbsolutePath().toString()); + assertThat(props.getRegion()).isEqualTo("us-ashburn-1"); + + }); + } + + private void createPrivateKey(Path tmp) throws Exception { + KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA"); + gen.initialize(2048); + KeyPair keyPair = gen.generateKeyPair(); + byte[] encoded = Pem.encoder().encode(keyPair.getPrivate()); + Files.write(tmp, encoded); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java index d23681cad1f..37c9b4bd9c2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java +++ 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java @@ -25,6 +25,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.oci.OCIEmbeddingModel; +import org.springframework.ai.oci.cohere.OCICohereChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -35,11 +36,15 @@ public class OCIGenAiAutoConfigurationIT { public static final String COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; + public static final String OCI_CHAT_MODEL_ID_KEY = "OCI_CHAT_MODEL_ID"; + private final String CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); private final String COMPARTMENT_ID = System.getenv(COMPARTMENT_ID_KEY); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + private final String CHAT_MODEL_ID = System.getenv(OCI_CHAT_MODEL_ID_KEY); + + private final ApplicationContextRunner embeddingContextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.oci.genai.authenticationType=file", "spring.ai.oci.genai.file=" + this.CONFIG_FILE, @@ -49,9 +54,19 @@ public class OCIGenAiAutoConfigurationIT { // @formatter:on ).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + private final ApplicationContextRunner cohereChatContextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.oci.genai.authenticationType=file", + "spring.ai.oci.genai.file=" + CONFIG_FILE, + "spring.ai.oci.genai.cohere.chat.options.compartment=" + COMPARTMENT_ID, + "spring.ai.oci.genai.cohere.chat.options.servingMode=on-demand", + "spring.ai.oci.genai.cohere.chat.options.model=" + CHAT_MODEL_ID + // @formatter:on + 
).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + @Test void embeddings() { - this.contextRunner.run(context -> { + embeddingContextRunner.run(context -> { OCIEmbeddingModel embeddingModel = context.getBean(OCIEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse response = embeddingModel @@ -61,4 +76,15 @@ void embeddings() { }); } + @Test + @EnabledIfEnvironmentVariable(named = OCIGenAiAutoConfigurationIT.OCI_CHAT_MODEL_ID_KEY, matches = ".+") + void cohereChat() { + cohereChatContextRunner.run(context -> { + OCICohereChatModel chatModel = context.getBean(OCICohereChatModel.class); + assertThat(chatModel).isNotNull(); + String response = chatModel.call("How many states are in the United States of America?"); + assertThat(response).isNotBlank(); + }); + } + }