From 683b9f9a0f2477feb55c096512c3f1c72e466b81 Mon Sep 17 00:00:00 2001 From: Tomasz Forys Date: Fri, 28 Nov 2025 00:38:14 +0100 Subject: [PATCH] GH-4985: MessageChatMemoryAdvisor with conversationId supplier Implementation of GH-4985 (https://github.com/spring-projects/spring-ai/issues/4985) * conversationId supplier support on MessageChatMemoryAdvisor Signed-off-by: Tomasz Forys --- .../advisor/MessageChatMemoryAdvisorIT.java | 66 ++++++++++++++++++- .../advisor/MessageChatMemoryAdvisor.java | 30 ++++++--- 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java index 7f5fa1aa94c..f49e1b97fd5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -133,6 +134,70 @@ void shouldHandleMultipleUserMessagesInPrompt() { assertThat(followUpAnswer).containsIgnoringCase("David"); } + /** + * Tests that the advisor correctly uses a conversation ID supplier when provided. + */ + @Test + protected void testUseSupplierConversationId() { + // Arrange + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + // ConversationId circular iterator + String firstConversationId = "conversationId-1"; + String secondConversationId = "conversationId-2"; + AtomicReference conversationIdHolder = new AtomicReference<>(firstConversationId); + + // Create advisor with conversation id supplier returning conversationId interchangeable + var advisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationIdSupplier(conversationIdHolder::get).build(); + + ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build(); + + String firstQuestion = "What is the capital of Germany?"; + String firstAnswer = chatClient.prompt() + .user(firstQuestion) + .call() + .content(); + logger.info("First question: {}", firstQuestion); + logger.info("First answer: {}", firstAnswer); + // Assert response is relevant + assertThat(firstAnswer).containsIgnoringCase("Berlin"); + + conversationIdHolder.set(secondConversationId); + String secondQuestion = "What is the capital of Poland?"; + String secondAnswer = chatClient.prompt() + .user(secondQuestion) + .call() + .content(); + logger.info("Second question: {}", secondQuestion); + logger.info("Second answer: {}", secondAnswer); + // Assert response is relevant + assertThat(secondAnswer).containsIgnoringCase("Warsaw"); + + conversationIdHolder.set(firstConversationId); + String thirdQuestion = "What is the capital of Spain?"; + String thirdAnswer = chatClient.prompt() + .user(thirdQuestion) + .call() + .content(); + logger.info("Third question: {}", thirdQuestion); + logger.info("Third answer: {}", thirdAnswer); + // Assert response is relevant + assertThat(thirdAnswer).containsIgnoringCase("Madrid"); + + // Verify first conversation memory contains the firstQuestion, firstAnswer, thirdQuestion and thirdAnswer + List firstMemoryMessages = chatMemory.get(firstConversationId); + assertThat(firstMemoryMessages).hasSize(4); + assertThat(firstMemoryMessages.get(0).getText()).isEqualTo(firstQuestion); + assertThat(firstMemoryMessages.get(2).getText()).isEqualTo(thirdQuestion); + + // Verify second conversation memory contains the secondQuestion and secondAnswer + List secondMemoryMessages = chatMemory.get(secondConversationId); + assertThat(secondMemoryMessages).hasSize(2); + assertThat(secondMemoryMessages.get(0).getText()).isEqualTo(secondQuestion); + } + @Test void shouldHandleNonExistentConversation() { testHandleNonExistentConversation(); @@ -157,7 +222,6 @@ void shouldStoreCompleteContentInStreamingMode() { String userInput = "Tell me a short joke about programming"; // Collect the streaming responses - List streamedResponses = new ArrayList<>(); chatClient.prompt() .user(userInput) .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 1b8bbea84e9..82741b875cd 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -42,22 +43,23 @@ * @author Christian Tzolov * @author Mark Pollack * @author Thomas Vitale + * @author Tomasz Forys * @since 1.0.0 */ public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor { private final ChatMemory chatMemory; - private final String defaultConversationId; + private final Supplier defaultConversationId; private final int order; private final Scheduler scheduler; - private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, - Scheduler scheduler) { + private MessageChatMemoryAdvisor(ChatMemory chatMemory, Supplier defaultConversationId, int order, + Scheduler scheduler) { Assert.notNull(chatMemory, "chatMemory cannot be null"); - Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); + Assert.hasText(defaultConversationId.get(), "defaultConversationId cannot be null or empty"); Assert.notNull(scheduler, "scheduler cannot be null"); this.chatMemory = chatMemory; this.defaultConversationId = defaultConversationId; @@ -77,7 +79,7 @@ public Scheduler getScheduler() { @Override public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { - String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId); + String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId.get()); // 1. Retrieve the chat memory for the current conversation. List memoryMessages = this.chatMemory.get(conversationId); @@ -108,7 +110,7 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh .map(g -> (Message) g.getOutput()) .toList(); } - this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId.get()), assistantMessages); return chatClientResponse; } @@ -134,7 +136,7 @@ public static Builder builder(ChatMemory chatMemory) { public static final class Builder { - private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID; + private Supplier conversationIdSupplier = () -> ChatMemory.DEFAULT_CONVERSATION_ID; private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; @@ -152,7 +154,17 @@ private Builder(ChatMemory chatMemory) { * @return the builder */ public Builder conversationId(String conversationId) { - this.conversationId = conversationId; + this.conversationIdSupplier = () -> conversationId; + return this; + } + + /** + * Set the conversation id supplier. + * @param conversationIdSupplier the conversation id supplier + * @return the builder + */ + public Builder conversationIdSupplier(Supplier conversationIdSupplier) { + this.conversationIdSupplier = conversationIdSupplier; return this; } @@ -176,7 +188,7 @@ public Builder scheduler(Scheduler scheduler) { * @return the advisor */ public MessageChatMemoryAdvisor build() { - return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler); + return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationIdSupplier, this.order, this.scheduler); } }