Skip to content

Commit 3d252ca

Browse files
committed
Copy metadata in ChatResponse Builder's from() method
- Update ChatResponse.Builder to copy all metadata fields when using from() - Expand test case to verify correct metadata copying in QA advisor Resolves #1537
1 parent a39aadc commit 3d252ca

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ private Builder() {
125125

126126
public Builder from(ChatResponse other) {
127127
this.generations = other.generations;
128+
this.chatResponseMetadataBuilder.withModel(other.chatResponseMetadata.getModel());
129+
this.chatResponseMetadataBuilder.withId(other.chatResponseMetadata.getId());
130+
this.chatResponseMetadataBuilder.withRateLimit(other.chatResponseMetadata.getRateLimit());
131+
this.chatResponseMetadataBuilder.withUsage(other.chatResponseMetadata.getUsage());
132+
this.chatResponseMetadataBuilder.withPromptMetadata(other.chatResponseMetadata.getPromptMetadata());
128133
Set<Map.Entry<String, Object>> entries = other.chatResponseMetadata.entrySet();
129134
for (Map.Entry<String, Object> entry : entries) {
130135
this.chatResponseMetadataBuilder.withKeyValue(entry.getKey(), entry.getValue());

spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import static org.assertj.core.api.Assertions.assertThat;
2020
import static org.mockito.Mockito.when;
2121

22+
import java.time.Duration;
2223
import java.util.List;
24+
import java.util.Map;
2325

2426
import org.junit.jupiter.api.Test;
2527
import org.junit.jupiter.api.extension.ExtendWith;
@@ -28,8 +30,12 @@
2830
import org.mockito.Mock;
2931
import org.mockito.junit.jupiter.MockitoExtension;
3032
import org.springframework.ai.chat.client.ChatClient;
33+
import org.springframework.ai.chat.messages.AssistantMessage;
3134
import org.springframework.ai.chat.messages.Message;
3235
import org.springframework.ai.chat.messages.MessageType;
36+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
37+
import org.springframework.ai.chat.metadata.DefaultUsage;
38+
import org.springframework.ai.chat.metadata.RateLimit;
3339
import org.springframework.ai.chat.model.ChatModel;
3440
import org.springframework.ai.chat.model.ChatResponse;
3541
import org.springframework.ai.chat.model.Generation;
@@ -60,8 +66,50 @@ public class QuestionAnswerAdvisorTests {
6066
@Test
6167
public void qaAdvisorWithDynamicFilterExpressions() {
6268

69+
// @formatter:off
6370
when(chatModel.call(promptCaptor.capture()))
64-
.thenReturn(new ChatResponse(List.of(new Generation("Your answer is ZXY"))));
71+
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
72+
ChatResponseMetadata.builder()
73+
.withId("678")
74+
.withModel("model1")
75+
.withKeyValue("key6", "value6")
76+
.withMetadata(Map.of("key1","value1" ))
77+
.withPromptMetadata(null)
78+
.withRateLimit(new RateLimit() {
79+
80+
@Override
81+
public Long getRequestsLimit() {
82+
return 5l;
83+
}
84+
85+
@Override
86+
public Long getRequestsRemaining() {
87+
return 6l;
88+
}
89+
90+
@Override
91+
public Duration getRequestsReset() {
92+
return Duration.ofSeconds(7);
93+
}
94+
95+
@Override
96+
public Long getTokensLimit() {
97+
return 8l;
98+
}
99+
100+
@Override
101+
public Long getTokensRemaining() {
102+
return 8l;
103+
}
104+
105+
@Override
106+
public Duration getTokensReset() {
107+
return Duration.ofSeconds(9);
108+
}
109+
})
110+
.withUsage(new DefaultUsage(6l, 7l))
111+
.build()));
112+
// @formatter:on
65113

66114
when(vectorStore.similaritySearch(vectorSearchCaptor.capture()))
67115
.thenReturn(List.of(new Document("doc1"), new Document("doc2")));
@@ -75,13 +123,33 @@ public void qaAdvisorWithDynamicFilterExpressions() {
75123
.build();
76124

77125
// @formatter:off
78-
var content = chatClient.prompt()
126+
var response = chatClient.prompt()
79127
.user("Please answer my question XYZ")
80128
.advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'"))
81129
.call()
82-
.content();
130+
.chatResponse();
83131
//formatter:on
84132

133+
// Ensure the metadata is correctly copied over
134+
assertThat(response.getMetadata().getModel()).isEqualTo("model1");
135+
assertThat(response.getMetadata().getId()).isEqualTo("678");
136+
assertThat(response.getMetadata().getRateLimit().getRequestsLimit()).isEqualTo(5l);
137+
assertThat(response.getMetadata().getRateLimit().getRequestsRemaining()).isEqualTo(6l);
138+
assertThat(response.getMetadata().getRateLimit().getRequestsReset()).isEqualTo(Duration.ofSeconds(7));
139+
assertThat(response.getMetadata().getRateLimit().getTokensLimit()).isEqualTo(8l);
140+
assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8l);
141+
assertThat(response.getMetadata().getRateLimit().getTokensReset()).isEqualTo(Duration.ofSeconds(9));
142+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6l);
143+
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7l);
144+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6l + 7l);
145+
assertThat(response.getMetadata().get("key6").toString()).isEqualTo("value6");
146+
assertThat(response.getMetadata().get("key1").toString()).isEqualTo("value1");
147+
148+
149+
150+
151+
String content = response.getResult().getOutput().getContent();
152+
85153
assertThat(content).isEqualTo("Your answer is ZXY");
86154

87155
Message systemMessage = promptCaptor.getValue().getInstructions().get(0);
@@ -112,5 +180,7 @@ public void qaAdvisorWithDynamicFilterExpressions() {
112180
assertThat(vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build());
113181
assertThat(vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d);
114182
assertThat(vectorSearchCaptor.getValue().getTopK()).isEqualTo(6);
183+
184+
115185
}
116186
}

0 commit comments

Comments
 (0)