Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fine-symbols-jam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@openai/agents-extensions': patch
---

fix: preserve Gemini thought_signature in multi-turn tool calls
7 changes: 6 additions & 1 deletion packages/agents-extensions/src/aiSdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,9 @@ export class AiSdkModel implements Model {
name: toolCall.toolName,
arguments: toolCallArguments,
status: 'completed',
providerData: hasToolCalls ? result.providerMetadata : undefined,
providerData:
toolCall.providerMetadata ??
(hasToolCalls ? result.providerMetadata : undefined),
});
}

Expand Down Expand Up @@ -916,6 +918,9 @@ export class AiSdkModel implements Model {
name: (part as any).toolName,
arguments: (part as any).input ?? '',
status: 'completed',
...((part as any).providerMetadata
? { providerData: (part as any).providerMetadata }
: {}),
};
}
break;
Expand Down
210 changes: 210 additions & 0 deletions packages/agents-extensions/test/aiSdk.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,106 @@ describe('AiSdkModel.getResponse', () => {
]);
});

test('preserves per-tool-call providerMetadata (e.g., Gemini thoughtSignature)', async () => {
const toolCallProviderMetadata = {
google: { thoughtSignature: 'sig123' },
};
const resultProviderMetadata = {
google: { usageMetadata: { totalTokenCount: 100 } },
};

const model = new AiSdkModel(
stubModel({
async doGenerate() {
return {
content: [
{
type: 'tool-call',
toolCallId: 'c1',
toolName: 'get_weather',
input: { location: 'Tokyo' },
providerMetadata: toolCallProviderMetadata,
},
],
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
providerMetadata: resultProviderMetadata,
response: { id: 'resp-1' },
finishReason: 'tool-calls',
warnings: [],
} as any;
},
}),
);

const res = await withTrace('t', () =>
model.getResponse({
input: 'What is the weather in Tokyo?',
tools: [
{
type: 'function',
name: 'get_weather',
description: 'Get weather',
parameters: { type: 'object', properties: {} },
},
],
handoffs: [],
modelSettings: {},
outputType: 'text',
tracing: false,
} as any),
);

expect(res.output).toHaveLength(1);
expect(res.output[0]).toMatchObject({
type: 'function_call',
callId: 'c1',
name: 'get_weather',
providerData: toolCallProviderMetadata,
});
// Ensure we get per-tool-call metadata, not result-level metadata
expect(res.output[0].providerData).not.toEqual(resultProviderMetadata);
});

test('falls back to result.providerMetadata when toolCall.providerMetadata is undefined', async () => {
const resultProviderMetadata = { fallback: true };

const model = new AiSdkModel(
stubModel({
async doGenerate() {
return {
content: [
{
type: 'tool-call',
toolCallId: 'c1',
toolName: 'foo',
input: {},
// No providerMetadata on tool call
},
],
usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
providerMetadata: resultProviderMetadata,
response: { id: 'id' },
finishReason: 'tool-calls',
warnings: [],
} as any;
},
}),
);

const res = await withTrace('t', () =>
model.getResponse({
input: 'hi',
tools: [],
handoffs: [],
modelSettings: {},
outputType: 'text',
tracing: false,
} as any),
);

expect(res.output[0].providerData).toEqual(resultProviderMetadata);
});

test('propagates errors', async () => {
const model = new AiSdkModel(
stubModel({
Expand Down Expand Up @@ -905,6 +1005,116 @@ describe('AiSdkModel.getStreamedResponse', () => {
]);
});

test('preserves per-tool-call providerMetadata in streaming mode (e.g., Gemini thoughtSignature)', async () => {
const toolCallProviderMetadata = {
google: { thoughtSignature: 'stream-sig-456' },
};

const parts = [
{
type: 'tool-call',
toolCallId: 'c1',
toolName: 'get_weather',
input: '{"location":"Tokyo"}',
providerMetadata: toolCallProviderMetadata,
},
{ type: 'response-metadata', id: 'resp-stream-1' },
{
type: 'finish',
finishReason: 'tool-calls',
usage: { inputTokens: 10, outputTokens: 20 },
},
];

const model = new AiSdkModel(
stubModel({
async doStream() {
return {
stream: partsStream(parts),
} as any;
},
}),
);

const events: any[] = [];
for await (const ev of model.getStreamedResponse({
input: 'What is the weather?',
tools: [
{
type: 'function',
name: 'get_weather',
description: 'Get weather',
parameters: { type: 'object', properties: {} },
},
],
handoffs: [],
modelSettings: {},
outputType: 'text',
tracing: false,
} as any)) {
events.push(ev);
}

const final = events.at(-1);
expect(final.type).toBe('response_done');
expect(final.response.output).toHaveLength(1);
expect(final.response.output[0]).toMatchObject({
type: 'function_call',
callId: 'c1',
name: 'get_weather',
providerData: toolCallProviderMetadata,
});
});

test('omits providerData in streaming mode when providerMetadata is not present', async () => {
const parts = [
{
type: 'tool-call',
toolCallId: 'c1',
toolName: 'foo',
input: '{}',
// No providerMetadata
},
{
type: 'finish',
finishReason: 'tool-calls',
usage: { inputTokens: 1, outputTokens: 2 },
},
];

const model = new AiSdkModel(
stubModel({
async doStream() {
return {
stream: partsStream(parts),
} as any;
},
}),
);

const events: any[] = [];
for await (const ev of model.getStreamedResponse({
input: 'hi',
tools: [],
handoffs: [],
modelSettings: {},
outputType: 'text',
tracing: false,
} as any)) {
events.push(ev);
}

const final = events.at(-1);
expect(final.type).toBe('response_done');
expect(final.response.output[0]).toMatchObject({
type: 'function_call',
callId: 'c1',
name: 'foo',
});
// providerData should not be present when providerMetadata was not provided
expect(final.response.output[0].providerData).toBeUndefined();
});

test('propagates stream errors', async () => {
const err = new Error('bad');
const parts = [{ type: 'error', error: err }];
Expand Down