Skip to content

Commit 2c8c4e7

Browse files
committed
Make ToolCallAdvisor extensible with hook methods (#5004)
- Add protected doInitializeLoop, doBeforeCall, and doAfterCall hooks to allow subclasses to customize the tool calling loop behavior. - Update Builder to support inheritance via self-referential generics. Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 1fa0e07 commit 2c8c4e7

File tree

2 files changed

+213
-12
lines changed

2 files changed

+213
-12
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
*
4545
* @author Christian Tzolov
4646
*/
47-
public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {
47+
public class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {
4848

4949
private final ToolCallingManager toolCallingManager;
5050

@@ -57,7 +57,7 @@ public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {
5757
*/
5858
private final int advisorOrder;
5959

60-
private ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) {
60+
protected ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) {
6161
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
6262
Assert.isTrue(advisorOrder > BaseAdvisor.HIGHEST_PRECEDENCE && advisorOrder < BaseAdvisor.LOWEST_PRECEDENCE,
6363
"advisorOrder must be between HIGHEST_PRECEDENCE and LOWEST_PRECEDENCE");
@@ -76,7 +76,6 @@ public int getOrder() {
7676
return this.advisorOrder;
7777
}
7878

79-
@SuppressWarnings("null")
8079
@Override
8180
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
8281
Assert.notNull(callAdvisorChain, "callAdvisorChain must not be null");
@@ -88,6 +87,8 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
8887
"ToolCall Advisor requires ToolCallingChatOptions to be set in the ChatClientRequest options.");
8988
}
9089

90+
chatClientRequest = this.doInitializeLoop(chatClientRequest, callAdvisorChain);
91+
9192
// Overwrite the ToolCallingChatOptions to disable internal tool execution.
9293
var optionsCopy = (ToolCallingChatOptions) chatClientRequest.prompt().getOptions().copy();
9394

@@ -109,8 +110,12 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
109110
.build();
110111

111112
// Next Call
113+
processedChatClientRequest = this.doBeforeCall(processedChatClientRequest, callAdvisorChain);
114+
112115
chatClientResponse = callAdvisorChain.copy(this).nextCall(processedChatClientRequest);
113116

117+
chatClientResponse = this.doAfterCall(chatClientResponse, callAdvisorChain);
118+
114119
// After Call
115120

116121
// TODO: check that this is tool call is sufficiant for all chat models
@@ -148,6 +153,19 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
148153
return chatClientResponse;
149154
}
150155

156+
protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest,
157+
CallAdvisorChain callAdvisorChain) {
158+
return chatClientRequest;
159+
}
160+
161+
protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
162+
return chatClientRequest;
163+
}
164+
165+
protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse, CallAdvisorChain callAdvisorChain) {
166+
return chatClientResponse;
167+
}
168+
151169
@Override
152170
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
153171
StreamAdvisorChain streamAdvisorChain) {
@@ -158,30 +176,45 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
158176
* Creates a new Builder instance for constructing a ToolCallAdvisor.
159177
* @return a new Builder instance
160178
*/
161-
public static Builder builder() {
162-
return new Builder();
179+
public static Builder<?> builder() {
180+
return new Builder<>();
163181
}
164182

165183
/**
166184
* Builder for creating instances of ToolCallAdvisor.
185+
* <p>
186+
* This builder uses the self-referential generic pattern to support extensibility.
187+
*
188+
* @param <T> the builder type, used for self-referential generics to support method
189+
* chaining in subclasses
167190
*/
168-
public final static class Builder {
191+
public static class Builder<T extends Builder<T>> {
169192

170193
private ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
171194

172195
private int advisorOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 300;
173196

174-
private Builder() {
197+
protected Builder() {
198+
}
199+
200+
/**
201+
* Returns this builder cast to the appropriate type for method chaining.
202+
* Subclasses should override this method to return the correct type.
203+
* @return this builder instance
204+
*/
205+
@SuppressWarnings("unchecked")
206+
protected T self() {
207+
return (T) this;
175208
}
176209

177210
/**
178211
* Sets the ToolCallingManager to be used by the advisor.
179212
* @param toolCallingManager the ToolCallingManager instance
180213
* @return this Builder instance for method chaining
181214
*/
182-
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
215+
public T toolCallingManager(ToolCallingManager toolCallingManager) {
183216
this.toolCallingManager = toolCallingManager;
184-
return this;
217+
return self();
185218
}
186219

187220
/**
@@ -190,9 +223,25 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
190223
* LOWEST_PRECEDENCE
191224
* @return this Builder instance for method chaining
192225
*/
193-
public Builder advisorOrder(int advisorOrder) {
226+
public T advisorOrder(int advisorOrder) {
194227
this.advisorOrder = advisorOrder;
195-
return this;
228+
return self();
229+
}
230+
231+
/**
232+
* Returns the configured ToolCallingManager.
233+
* @return the ToolCallingManager instance
234+
*/
235+
protected ToolCallingManager getToolCallingManager() {
236+
return this.toolCallingManager;
237+
}
238+
239+
/**
240+
* Returns the configured advisor order.
241+
* @return the advisor order value
242+
*/
243+
protected int getAdvisorOrder() {
244+
return this.advisorOrder;
196245
}
197246

198247
/**

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,99 @@ void testGetOrder() {
377377
assertThat(advisor.getOrder()).isEqualTo(customOrder);
378378
}
379379

380+
@Test
381+
void testBuilderGetters() {
382+
ToolCallingManager customManager = mock(ToolCallingManager.class);
383+
int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 500;
384+
385+
ToolCallAdvisor.Builder<?> builder = ToolCallAdvisor.builder()
386+
.toolCallingManager(customManager)
387+
.advisorOrder(customOrder);
388+
389+
assertThat(builder.getToolCallingManager()).isEqualTo(customManager);
390+
assertThat(builder.getAdvisorOrder()).isEqualTo(customOrder);
391+
}
392+
393+
@Test
394+
void testExtendedAdvisorWithCustomHooks() {
395+
int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall
396+
397+
// Create extended advisor to verify hooks are called
398+
TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager,
399+
BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts);
400+
401+
ChatClientRequest request = createMockRequest(true);
402+
ChatClientResponse response = createMockResponse(false);
403+
404+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> response);
405+
406+
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
407+
.pushAll(List.of(advisor, terminalAdvisor))
408+
.build();
409+
410+
advisor.adviseCall(request, realChain);
411+
412+
// Verify hooks were called
413+
assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once
414+
assertThat(hookCallCounts[1]).isEqualTo(1); // doBeforeCall called once
415+
assertThat(hookCallCounts[2]).isEqualTo(1); // doAfterCall called once
416+
}
417+
418+
@Test
419+
void testExtendedAdvisorHooksCalledMultipleTimesWithToolCalls() {
420+
int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall
421+
422+
TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager,
423+
BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts);
424+
425+
ChatClientRequest request = createMockRequest(true);
426+
ChatClientResponse responseWithToolCall = createMockResponse(true);
427+
ChatClientResponse finalResponse = createMockResponse(false);
428+
429+
int[] callCount = { 0 };
430+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> {
431+
callCount[0]++;
432+
return callCount[0] == 1 ? responseWithToolCall : finalResponse;
433+
});
434+
435+
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
436+
.pushAll(List.of(advisor, terminalAdvisor))
437+
.build();
438+
439+
// Mock tool execution result
440+
List<Message> conversationHistory = List.of(new UserMessage("test"),
441+
AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build());
442+
ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder()
443+
.conversationHistory(conversationHistory)
444+
.build();
445+
when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class)))
446+
.thenReturn(toolExecutionResult);
447+
448+
advisor.adviseCall(request, realChain);
449+
450+
// Verify hooks were called correct number of times
451+
assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once
452+
// (before loop)
453+
assertThat(hookCallCounts[1]).isEqualTo(2); // doBeforeCall called twice (each
454+
// iteration)
455+
assertThat(hookCallCounts[2]).isEqualTo(2); // doAfterCall called twice (each
456+
// iteration)
457+
}
458+
459+
@Test
460+
void testExtendedBuilderWithCustomBuilder() {
461+
ToolCallingManager customManager = mock(ToolCallingManager.class);
462+
int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 450;
463+
464+
TestableToolCallAdvisor advisor = TestableToolCallAdvisor.testBuilder()
465+
.toolCallingManager(customManager)
466+
.advisorOrder(customOrder)
467+
.build();
468+
469+
assertThat(advisor).isNotNull();
470+
assertThat(advisor.getOrder()).isEqualTo(customOrder);
471+
}
472+
380473
// Helper methods
381474

382475
private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
@@ -472,6 +565,65 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
472565
return this.responseFunction.apply(req, chain);
473566
}
474567

475-
};
568+
}
569+
570+
/**
571+
* Test subclass of ToolCallAdvisor to verify extensibility and hook methods.
572+
*/
573+
private static class TestableToolCallAdvisor extends ToolCallAdvisor {
574+
575+
private final int[] hookCallCounts;
576+
577+
TestableToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder, int[] hookCallCounts) {
578+
super(toolCallingManager, advisorOrder);
579+
this.hookCallCounts = hookCallCounts;
580+
}
581+
582+
@Override
583+
protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest,
584+
CallAdvisorChain callAdvisorChain) {
585+
if (this.hookCallCounts != null) {
586+
this.hookCallCounts[0]++;
587+
}
588+
return super.doInitializeLoop(chatClientRequest, callAdvisorChain);
589+
}
590+
591+
@Override
592+
protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest,
593+
CallAdvisorChain callAdvisorChain) {
594+
if (this.hookCallCounts != null) {
595+
this.hookCallCounts[1]++;
596+
}
597+
return super.doBeforeCall(chatClientRequest, callAdvisorChain);
598+
}
599+
600+
@Override
601+
protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse,
602+
CallAdvisorChain callAdvisorChain) {
603+
if (this.hookCallCounts != null) {
604+
this.hookCallCounts[2]++;
605+
}
606+
return super.doAfterCall(chatClientResponse, callAdvisorChain);
607+
}
608+
609+
static TestableBuilder testBuilder() {
610+
return new TestableBuilder();
611+
}
612+
613+
static class TestableBuilder extends ToolCallAdvisor.Builder<TestableBuilder> {
614+
615+
@Override
616+
protected TestableBuilder self() {
617+
return this;
618+
}
619+
620+
@Override
621+
public TestableToolCallAdvisor build() {
622+
return new TestableToolCallAdvisor(getToolCallingManager(), getAdvisorOrder(), null);
623+
}
624+
625+
}
626+
627+
}
476628

477629
}

0 commit comments

Comments
 (0)