Skip to content

Commit 0646d1e

Browse files
CorgiBoyG authored and markpollack committed
Fix extraBody not being included in OpenAI API requests
Two issues prevented extraBody from reaching the HTTP request:

1. Options→Request merge: ChatCompletionRequest.extraBody lacked a @JsonProperty annotation, causing ModelOptionsUtils.merge() to filter it out. Added the @JsonProperty("extra_body") annotation.
2. Options→Options merge: ModelOptionsUtils.merge() replaces the entire extraBody map instead of merging keys, causing default values to be lost when runtime options also specify extraBody. Added an explicit mergeExtraBody() method to combine the maps, with runtime values taking precedence.

Added wire-level and unit tests to verify end-to-end behavior.
1 parent 2c854d3 commit 0646d1e

File tree

5 files changed

+336
-1
lines changed

5 files changed

+336
-1
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,13 +548,16 @@ Prompt buildRequestPrompt(Prompt prompt) {
548548
this.defaultOptions.getToolCallbacks()));
549549
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
550550
this.defaultOptions.getToolContext()));
551+
requestOptions
552+
.setExtraBody(mergeExtraBody(runtimeOptions.getExtraBody(), this.defaultOptions.getExtraBody()));
551553
}
552554
else {
553555
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
554556
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
555557
requestOptions.setToolNames(this.defaultOptions.getToolNames());
556558
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
557559
requestOptions.setToolContext(this.defaultOptions.getToolContext());
560+
requestOptions.setExtraBody(this.defaultOptions.getExtraBody());
558561
}
559562

560563
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
@@ -569,6 +572,21 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
569572
return mergedHttpHeaders;
570573
}
571574

575+
private Map<String, Object> mergeExtraBody(Map<String, Object> runtimeExtraBody,
576+
Map<String, Object> defaultExtraBody) {
577+
if (defaultExtraBody == null && runtimeExtraBody == null) {
578+
return null;
579+
}
580+
var merged = new HashMap<String, Object>();
581+
if (defaultExtraBody != null) {
582+
merged.putAll(defaultExtraBody);
583+
}
584+
if (runtimeExtraBody != null) {
585+
merged.putAll(runtimeExtraBody); // runtime overrides default
586+
}
587+
return merged.isEmpty() ? null : merged;
588+
}
589+
572590
/**
573591
* Accessible for testing.
574592
*/

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ public record ChatCompletionRequest(// @formatter:off
11391139
@JsonProperty("verbosity") String verbosity,
11401140
@JsonProperty("prompt_cache_key") String promptCacheKey,
11411141
@JsonProperty("safety_identifier") String safetyIdentifier,
1142-
Map<String, Object> extraBody) {
1142+
@JsonProperty("extra_body") Map<String, Object> extraBody) {
11431143

11441144
/**
11451145
* Compact constructor that ensures extraBody is initialized as a mutable HashMap

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,70 @@ void defaultOptionsTools() {
162162
assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME);
163163
}
164164

165+
@Test
166+
void extraBodyIsMergedIntoRequest() {
167+
var client = OpenAiChatModel.builder()
168+
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
169+
.defaultOptions(OpenAiChatOptions.builder()
170+
.model("gpt-4")
171+
.extraBody(Map.of("default_key", "default_value", "shared_key", "default"))
172+
.build())
173+
.build();
174+
175+
var prompt = client.buildRequestPrompt(new Prompt("Test",
176+
OpenAiChatOptions.builder()
177+
.extraBody(Map.of("runtime_key", "runtime_value", "shared_key", "runtime"))
178+
.build()));
179+
180+
var request = client.createRequest(prompt, false);
181+
182+
// Verify extraBody is present in the request
183+
assertThat(request.extraBody()).isNotNull();
184+
// Default key should be present
185+
assertThat(request.extraBody()).containsEntry("default_key", "default_value");
186+
// Runtime key should be present
187+
assertThat(request.extraBody()).containsEntry("runtime_key", "runtime_value");
188+
// Runtime should override default for shared key
189+
assertThat(request.extraBody()).containsEntry("shared_key", "runtime");
190+
}
191+
192+
@Test
193+
void extraBodyFromDefaultOptionsOnly() {
194+
var client = OpenAiChatModel.builder()
195+
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
196+
.defaultOptions(OpenAiChatOptions.builder()
197+
.model("gpt-4")
198+
.extraBody(Map.of("top_k", 50, "repetition_penalty", 1.1))
199+
.build())
200+
.build();
201+
202+
var prompt = client.buildRequestPrompt(new Prompt("Test"));
203+
204+
var request = client.createRequest(prompt, false);
205+
206+
// Verify extraBody from default options is present
207+
assertThat(request.extraBody()).isNotNull();
208+
assertThat(request.extraBody()).containsEntry("top_k", 50);
209+
assertThat(request.extraBody()).containsEntry("repetition_penalty", 1.1);
210+
}
211+
212+
@Test
213+
void extraBodyFromRuntimeOptionsOnly() {
214+
var client = OpenAiChatModel.builder()
215+
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
216+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
217+
.build();
218+
219+
var prompt = client.buildRequestPrompt(
220+
new Prompt("Test", OpenAiChatOptions.builder().extraBody(Map.of("enable_thinking", true)).build()));
221+
222+
var request = client.createRequest(prompt, false);
223+
224+
// Verify extraBody from runtime options is present
225+
assertThat(request.extraBody()).isNotNull();
226+
assertThat(request.extraBody()).containsEntry("enable_thinking", true);
227+
}
228+
165229
static class TestToolCallback implements ToolCallback {
166230

167231
private final ToolDefinition toolDefinition;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/ExtraBodySerializationTest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import com.fasterxml.jackson.databind.ObjectMapper;
2323
import org.junit.jupiter.api.Test;
2424

25+
import org.springframework.ai.model.ModelOptionsUtils;
26+
import org.springframework.ai.openai.OpenAiChatOptions;
2527
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
2628

2729
import static org.assertj.core.api.Assertions.assertThat;
@@ -208,4 +210,25 @@ void testDeserializationWithComplexExtraFields() throws Exception {
208210
assertThat(request.extraBody().get("stop_token_ids")).isInstanceOf(List.class);
209211
}
210212

213+
@Test
214+
void testMergeWithExtraBody() throws Exception {
215+
// Arrange: Create OpenAiChatOptions with extraBody
216+
OpenAiChatOptions requestOptions = OpenAiChatOptions.builder()
217+
.model("test-model")
218+
.extraBody(Map.of("enable_thinking", true, "max_depth", 10))
219+
.build();
220+
221+
// Create empty ChatCompletionRequest
222+
ChatCompletionRequest request = new ChatCompletionRequest(null, null);
223+
224+
// Act: Merge options into request
225+
request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);
226+
227+
// Assert: Verify extraBody was successfully merged
228+
assertThat(request.extraBody()).isNotNull();
229+
assertThat(request.extraBody()).containsEntry("enable_thinking", true);
230+
assertThat(request.extraBody()).containsEntry("max_depth", 10);
231+
assertThat(request.model()).isEqualTo("test-model");
232+
}
233+
211234
}
models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ExtraBodyWireTest.java

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat;
18+
19+
import java.util.Map;
20+
21+
import com.fasterxml.jackson.databind.JsonNode;
22+
import com.fasterxml.jackson.databind.ObjectMapper;
23+
import okhttp3.mockwebserver.MockResponse;
24+
import okhttp3.mockwebserver.MockWebServer;
25+
import okhttp3.mockwebserver.RecordedRequest;
26+
import org.junit.jupiter.api.AfterEach;
27+
import org.junit.jupiter.api.BeforeEach;
28+
import org.junit.jupiter.api.Test;
29+
30+
import org.springframework.ai.chat.prompt.Prompt;
31+
import org.springframework.ai.openai.OpenAiChatModel;
32+
import org.springframework.ai.openai.OpenAiChatOptions;
33+
import org.springframework.ai.openai.api.OpenAiApi;
34+
import org.springframework.http.HttpHeaders;
35+
import org.springframework.http.MediaType;
36+
37+
import static org.assertj.core.api.Assertions.assertThat;
38+
39+
/**
40+
* Test to verify that extraBody parameters are correctly included in the HTTP request
41+
* sent to the OpenAI API. This test captures the actual wire-level JSON to verify the
42+
* end-to-end flow from OpenAiChatOptions through to the HTTP request body.
43+
*
44+
* <p>
45+
* These tests ensure that extraBody fields are:
46+
* <ul>
47+
* <li>Correctly merged from OpenAiChatOptions into ChatCompletionRequest</li>
48+
* <li>Flattened to the top level of the JSON (not nested under "extra_body")</li>
49+
* <li>Properly handled when set in default options, runtime options, or both</li>
50+
* </ul>
51+
*
52+
* @author Mark Pollack
53+
* @see <a href="https://github.com/spring-projects/spring-ai/issues/4867">GitHub Issue
54+
* #4867</a>
55+
*/
56+
class ExtraBodyWireTest {
57+
58+
private MockWebServer mockWebServer;
59+
60+
private final ObjectMapper objectMapper = new ObjectMapper();
61+
62+
@BeforeEach
63+
void setUp() throws Exception {
64+
this.mockWebServer = new MockWebServer();
65+
this.mockWebServer.start();
66+
}
67+
68+
@AfterEach
69+
void tearDown() throws Exception {
70+
this.mockWebServer.shutdown();
71+
}
72+
73+
@Test
74+
void extraBodyFromRuntimeOptionsAppearsInHttpRequest() throws Exception {
75+
// Arrange: Mock response
76+
this.mockWebServer.enqueue(createMockResponse());
77+
78+
OpenAiApi api = OpenAiApi.builder().apiKey("test-key").baseUrl(this.mockWebServer.url("/").toString()).build();
79+
80+
OpenAiChatModel chatModel = OpenAiChatModel.builder()
81+
.openAiApi(api)
82+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
83+
.build();
84+
85+
// Act: Call with extraBody in runtime options
86+
OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder()
87+
.extraBody(Map.of("top_k", 50, "repetition_penalty", 1.1))
88+
.build();
89+
90+
chatModel.call(new Prompt("Hello", runtimeOptions));
91+
92+
// Assert: Verify the wire-level JSON contains flattened extraBody fields
93+
RecordedRequest recordedRequest = this.mockWebServer.takeRequest();
94+
String requestBody = recordedRequest.getBody().readUtf8();
95+
JsonNode json = this.objectMapper.readTree(requestBody);
96+
97+
// Verify extraBody fields are at top level
98+
assertThat(json.has("top_k")).as("top_k should be at top level").isTrue();
99+
assertThat(json.get("top_k").asInt()).isEqualTo(50);
100+
assertThat(json.has("repetition_penalty")).as("repetition_penalty should be at top level").isTrue();
101+
assertThat(json.get("repetition_penalty").asDouble()).isEqualTo(1.1);
102+
103+
// Verify extra_body is NOT a nested object (fields are flattened)
104+
assertThat(json.has("extra_body")).as("extra_body should NOT appear as nested object").isFalse();
105+
}
106+
107+
@Test
108+
void extraBodyFromDefaultOptionsAppearsInHttpRequest() throws Exception {
109+
// Arrange: Mock response
110+
this.mockWebServer.enqueue(createMockResponse());
111+
112+
OpenAiApi api = OpenAiApi.builder().apiKey("test-key").baseUrl(this.mockWebServer.url("/").toString()).build();
113+
114+
// Set extraBody in DEFAULT options
115+
OpenAiChatModel chatModel = OpenAiChatModel.builder()
116+
.openAiApi(api)
117+
.defaultOptions(OpenAiChatOptions.builder()
118+
.model("gpt-4")
119+
.extraBody(Map.of("enable_thinking", true, "top_k", 40))
120+
.build())
121+
.build();
122+
123+
// Act: Call without runtime options
124+
chatModel.call(new Prompt("Hello"));
125+
126+
// Assert: Verify wire-level JSON
127+
RecordedRequest recordedRequest = this.mockWebServer.takeRequest();
128+
String requestBody = recordedRequest.getBody().readUtf8();
129+
JsonNode json = this.objectMapper.readTree(requestBody);
130+
131+
assertThat(json.has("enable_thinking")).isTrue();
132+
assertThat(json.get("enable_thinking").asBoolean()).isTrue();
133+
assertThat(json.has("top_k")).isTrue();
134+
assertThat(json.get("top_k").asInt()).isEqualTo(40);
135+
136+
// Verify extra_body is NOT a nested object
137+
assertThat(json.has("extra_body")).as("extra_body should NOT appear as nested object").isFalse();
138+
}
139+
140+
@Test
141+
void runtimeExtraBodyOverridesDefaultExtraBody() throws Exception {
142+
// Arrange
143+
this.mockWebServer.enqueue(createMockResponse());
144+
145+
OpenAiApi api = OpenAiApi.builder().apiKey("test-key").baseUrl(this.mockWebServer.url("/").toString()).build();
146+
147+
OpenAiChatModel chatModel = OpenAiChatModel.builder()
148+
.openAiApi(api)
149+
.defaultOptions(OpenAiChatOptions.builder()
150+
.model("gpt-4")
151+
.extraBody(Map.of("top_k", 40, "default_only", "value"))
152+
.build())
153+
.build();
154+
155+
// Act: Runtime extraBody should override default for same key
156+
OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder()
157+
.extraBody(Map.of("top_k", 100, "runtime_only", "value"))
158+
.build();
159+
160+
chatModel.call(new Prompt("Hello", runtimeOptions));
161+
162+
// Assert
163+
RecordedRequest recordedRequest = this.mockWebServer.takeRequest();
164+
String requestBody = recordedRequest.getBody().readUtf8();
165+
JsonNode json = this.objectMapper.readTree(requestBody);
166+
167+
// Runtime overrides default
168+
assertThat(json.get("top_k").asInt()).isEqualTo(100);
169+
// Both unique keys present
170+
assertThat(json.has("default_only")).isTrue();
171+
assertThat(json.has("runtime_only")).isTrue();
172+
173+
// Verify extra_body is NOT a nested object
174+
assertThat(json.has("extra_body")).as("extra_body should NOT appear as nested object").isFalse();
175+
}
176+
177+
@Test
178+
void extraBodyWithVllmParameters() throws Exception {
179+
// Arrange: Test with real vLLM parameters
180+
this.mockWebServer.enqueue(createMockResponse());
181+
182+
OpenAiApi api = OpenAiApi.builder().apiKey("test-key").baseUrl(this.mockWebServer.url("/").toString()).build();
183+
184+
OpenAiChatModel chatModel = OpenAiChatModel.builder()
185+
.openAiApi(api)
186+
.defaultOptions(OpenAiChatOptions.builder().model("meta-llama/Llama-3-8B-Instruct").build())
187+
.build();
188+
189+
// Act: Use real vLLM parameters
190+
OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder()
191+
.extraBody(Map.of("top_k", 50, "min_p", 0.05, "repetition_penalty", 1.1, "best_of", 3))
192+
.build();
193+
194+
chatModel.call(new Prompt("Hello", runtimeOptions));
195+
196+
// Assert
197+
RecordedRequest recordedRequest = this.mockWebServer.takeRequest();
198+
String requestBody = recordedRequest.getBody().readUtf8();
199+
JsonNode json = this.objectMapper.readTree(requestBody);
200+
201+
// All vLLM parameters should be at top level
202+
assertThat(json.get("top_k").asInt()).isEqualTo(50);
203+
assertThat(json.get("min_p").asDouble()).isEqualTo(0.05);
204+
assertThat(json.get("repetition_penalty").asDouble()).isEqualTo(1.1);
205+
assertThat(json.get("best_of").asInt()).isEqualTo(3);
206+
207+
// Verify model is also set correctly
208+
assertThat(json.get("model").asText()).isEqualTo("meta-llama/Llama-3-8B-Instruct");
209+
}
210+
211+
private MockResponse createMockResponse() {
212+
return new MockResponse().setResponseCode(200)
213+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
214+
.setBody("""
215+
{
216+
"id": "chatcmpl-123",
217+
"object": "chat.completion",
218+
"created": 1677652288,
219+
"model": "gpt-4",
220+
"choices": [{
221+
"index": 0,
222+
"message": {"role": "assistant", "content": "Hello!"},
223+
"finish_reason": "stop"
224+
}],
225+
"usage": {"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11}
226+
}
227+
""");
228+
}
229+
230+
}

0 commit comments

Comments
 (0)