1919import static org .assertj .core .api .Assertions .assertThat ;
2020import static org .mockito .Mockito .when ;
2121
22+ import java .time .Duration ;
2223import java .util .List ;
24+ import java .util .Map ;
2325
2426import org .junit .jupiter .api .Test ;
2527import org .junit .jupiter .api .extension .ExtendWith ;
2830import org .mockito .Mock ;
2931import org .mockito .junit .jupiter .MockitoExtension ;
3032import org .springframework .ai .chat .client .ChatClient ;
33+ import org .springframework .ai .chat .messages .AssistantMessage ;
3134import org .springframework .ai .chat .messages .Message ;
3235import org .springframework .ai .chat .messages .MessageType ;
36+ import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
37+ import org .springframework .ai .chat .metadata .DefaultUsage ;
38+ import org .springframework .ai .chat .metadata .RateLimit ;
3339import org .springframework .ai .chat .model .ChatModel ;
3440import org .springframework .ai .chat .model .ChatResponse ;
3541import org .springframework .ai .chat .model .Generation ;
@@ -60,8 +66,50 @@ public class QuestionAnswerAdvisorTests {
6066 @ Test
6167 public void qaAdvisorWithDynamicFilterExpressions () {
6268
69+ // @formatter:off
6370 when (chatModel .call (promptCaptor .capture ()))
64- .thenReturn (new ChatResponse (List .of (new Generation ("Your answer is ZXY" ))));
71+ .thenReturn (new ChatResponse (List .of (new Generation (new AssistantMessage ("Your answer is ZXY" ))),
72+ ChatResponseMetadata .builder ()
73+ .withId ("678" )
74+ .withModel ("model1" )
75+ .withKeyValue ("key6" , "value6" )
76+ .withMetadata (Map .of ("key1" ,"value1" ))
77+ .withPromptMetadata (null )
78+ .withRateLimit (new RateLimit () {
79+
80+ @ Override
81+ public Long getRequestsLimit () {
82+ return 5l ;
83+ }
84+
85+ @ Override
86+ public Long getRequestsRemaining () {
87+ return 6l ;
88+ }
89+
90+ @ Override
91+ public Duration getRequestsReset () {
92+ return Duration .ofSeconds (7 );
93+ }
94+
95+ @ Override
96+ public Long getTokensLimit () {
97+ return 8l ;
98+ }
99+
100+ @ Override
101+ public Long getTokensRemaining () {
102+ return 8l ;
103+ }
104+
105+ @ Override
106+ public Duration getTokensReset () {
107+ return Duration .ofSeconds (9 );
108+ }
109+ })
110+ .withUsage (new DefaultUsage (6l , 7l ))
111+ .build ()));
112+ // @formatter:on
65113
66114 when (vectorStore .similaritySearch (vectorSearchCaptor .capture ()))
67115 .thenReturn (List .of (new Document ("doc1" ), new Document ("doc2" )));
@@ -75,13 +123,33 @@ public void qaAdvisorWithDynamicFilterExpressions() {
75123 .build ();
76124
77125 // @formatter:off
78- var content = chatClient .prompt ()
126+ var response = chatClient .prompt ()
79127 .user ("Please answer my question XYZ" )
80128 .advisors (a -> a .param (QuestionAnswerAdvisor .FILTER_EXPRESSION , "type == 'Spring'" ))
81129 .call ()
82- .content ();
130+ .chatResponse ();
83131 //formatter:on
84132
133+ // Ensure the metadata is correctly copied over
134+ assertThat (response .getMetadata ().getModel ()).isEqualTo ("model1" );
135+ assertThat (response .getMetadata ().getId ()).isEqualTo ("678" );
136+ assertThat (response .getMetadata ().getRateLimit ().getRequestsLimit ()).isEqualTo (5l );
137+ assertThat (response .getMetadata ().getRateLimit ().getRequestsRemaining ()).isEqualTo (6l );
138+ assertThat (response .getMetadata ().getRateLimit ().getRequestsReset ()).isEqualTo (Duration .ofSeconds (7 ));
139+ assertThat (response .getMetadata ().getRateLimit ().getTokensLimit ()).isEqualTo (8l );
140+ assertThat (response .getMetadata ().getRateLimit ().getTokensRemaining ()).isEqualTo (8l );
141+ assertThat (response .getMetadata ().getRateLimit ().getTokensReset ()).isEqualTo (Duration .ofSeconds (9 ));
142+ assertThat (response .getMetadata ().getUsage ().getPromptTokens ()).isEqualTo (6l );
143+ assertThat (response .getMetadata ().getUsage ().getGenerationTokens ()).isEqualTo (7l );
144+ assertThat (response .getMetadata ().getUsage ().getTotalTokens ()).isEqualTo (6l + 7l );
145+ assertThat (response .getMetadata ().get ("key6" ).toString ()).isEqualTo ("value6" );
146+ assertThat (response .getMetadata ().get ("key1" ).toString ()).isEqualTo ("value1" );
147+
148+
149+
150+
151+ String content = response .getResult ().getOutput ().getContent ();
152+
85153 assertThat (content ).isEqualTo ("Your answer is ZXY" );
86154
87155 Message systemMessage = promptCaptor .getValue ().getInstructions ().get (0 );
@@ -112,5 +180,7 @@ public void qaAdvisorWithDynamicFilterExpressions() {
112180 assertThat (vectorSearchCaptor .getValue ().getFilterExpression ()).isEqualTo (new FilterExpressionBuilder ().eq ("type" , "Spring" ).build ());
113181 assertThat (vectorSearchCaptor .getValue ().getSimilarityThreshold ()).isEqualTo (0.99d );
114182 assertThat (vectorSearchCaptor .getValue ().getTopK ()).isEqualTo (6 );
183+
184+
115185 }
116186}
0 commit comments