Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -133,6 +134,70 @@ void shouldHandleMultipleUserMessagesInPrompt() {
assertThat(followUpAnswer).containsIgnoringCase("David");
}

/**
* Tests that the advisor correctly uses a conversation ID supplier when provided.
*/
@Test
protected void testUseSupplierConversationId() {
// Arrange
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

// ConversationId circular iterator
String firstConversationId = "conversationId-1";
String secondConversationId = "conversationId-2";
AtomicReference<String> conversationIdHolder = new AtomicReference<>(firstConversationId);

// Create advisor with conversation id supplier returning conversationId interchangeable
var advisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationIdSupplier(conversationIdHolder::get).build();

ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();

String firstQuestion = "What is the capital of Germany?";
String firstAnswer = chatClient.prompt()
.user(firstQuestion)
.call()
.content();
logger.info("First question: {}", firstQuestion);
logger.info("First answer: {}", firstAnswer);
// Assert response is relevant
assertThat(firstAnswer).containsIgnoringCase("Berlin");

conversationIdHolder.set(secondConversationId);
String secondQuestion = "What is the capital of Poland?";
String secondAnswer = chatClient.prompt()
.user(secondQuestion)
.call()
.content();
logger.info("Second question: {}", secondQuestion);
logger.info("Second answer: {}", secondAnswer);
// Assert response is relevant
assertThat(secondAnswer).containsIgnoringCase("Warsaw");

conversationIdHolder.set(firstConversationId);
String thirdQuestion = "What is the capital of Spain?";
String thirdAnswer = chatClient.prompt()
.user(thirdQuestion)
.call()
.content();
logger.info("Third question: {}", thirdQuestion);
logger.info("Third answer: {}", thirdAnswer);
// Assert response is relevant
assertThat(thirdAnswer).containsIgnoringCase("Madrid");

// Verify first conversation memory contains the firstQuestion, firstAnswer, thirdQuestion and thirdAnswer
List<Message> firstMemoryMessages = chatMemory.get(firstConversationId);
assertThat(firstMemoryMessages).hasSize(4);
assertThat(firstMemoryMessages.get(0).getText()).isEqualTo(firstQuestion);
assertThat(firstMemoryMessages.get(2).getText()).isEqualTo(thirdQuestion);

// Verify second conversation memory contains the secondQuestion and secondAnswer
List<Message> secondMemoryMessages = chatMemory.get(secondConversationId);
assertThat(secondMemoryMessages).hasSize(2);
assertThat(secondMemoryMessages.get(0).getText()).isEqualTo(secondQuestion);
}

@Test
void shouldHandleNonExistentConversation() {
testHandleNonExistentConversation();
Expand All @@ -157,7 +222,6 @@ void shouldStoreCompleteContentInStreamingMode() {
String userInput = "Tell me a short joke about programming";

// Collect the streaming responses
List<String> streamedResponses = new ArrayList<>();
chatClient.prompt()
.user(userInput)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -42,22 +43,23 @@
* @author Christian Tzolov
* @author Mark Pollack
* @author Thomas Vitale
* @author Tomasz Forys
* @since 1.0.0
*/
public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {

private final ChatMemory chatMemory;

private final String defaultConversationId;
private final Supplier<String> defaultConversationId;

private final int order;

private final Scheduler scheduler;

private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order,
Scheduler scheduler) {
private MessageChatMemoryAdvisor(ChatMemory chatMemory, Supplier<String> defaultConversationId, int order,
Scheduler scheduler) {
Assert.notNull(chatMemory, "chatMemory cannot be null");
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
Assert.hasText(defaultConversationId.get(), "defaultConversationId cannot be null or empty");
Assert.notNull(scheduler, "scheduler cannot be null");
this.chatMemory = chatMemory;
this.defaultConversationId = defaultConversationId;
Expand All @@ -77,7 +79,7 @@ public Scheduler getScheduler() {

@Override
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId.get());

// 1. Retrieve the chat memory for the current conversation.
List<Message> memoryMessages = this.chatMemory.get(conversationId);
Expand Down Expand Up @@ -108,7 +110,7 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
.map(g -> (Message) g.getOutput())
.toList();
}
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId.get()),
assistantMessages);
return chatClientResponse;
}
Expand All @@ -134,7 +136,7 @@ public static Builder builder(ChatMemory chatMemory) {

public static final class Builder {

private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
private Supplier<String> conversationIdSupplier = () -> ChatMemory.DEFAULT_CONVERSATION_ID;

private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;

Expand All @@ -152,7 +154,17 @@ private Builder(ChatMemory chatMemory) {
* @return the builder
*/
public Builder conversationId(String conversationId) {
this.conversationId = conversationId;
this.conversationIdSupplier = () -> conversationId;
return this;
}

/**
* Set the conversation id supplier.
* @param conversationIdSupplier the conversation id supplier
* @return the builder
*/
public Builder conversationIdSupplier(Supplier<String> conversationIdSupplier) {
this.conversationIdSupplier = conversationIdSupplier;
return this;
}

Expand All @@ -176,7 +188,7 @@ public Builder scheduler(Scheduler scheduler) {
* @return the advisor
*/
public MessageChatMemoryAdvisor build() {
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler);
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationIdSupplier, this.order, this.scheduler);
}

}
Expand Down