Skip to content

Commit a1da64c

Browse files
authored
Support raw CoT reasoning from LM Studio and other OpenAI Responses-compatible APIs (#3559)
1 parent 84fdd7a commit a1da64c

23 files changed

+1324
-824
lines changed

docs/thinking.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ agent = Agent(model, model_settings=settings)
3838
...
3939
```
4040

41+
!!! note "Raw reasoning without summaries"
42+
Some OpenAI-compatible APIs (such as LM Studio, vLLM, or OpenRouter with gpt-oss models) may return raw reasoning content without reasoning summaries. In this case, [`ThinkingPart.content`][pydantic_ai.messages.ThinkingPart.content] will be empty, and the raw reasoning is instead available in the part's `provider_details['raw_content']`. Because [OpenAI's guidance](https://cookbook.openai.com/examples/responses_api/reasoning_items) is that raw reasoning should not be shown directly to users, it is stored in `provider_details` rather than in the main `content` field.
43+
4144
## Anthropic
4245

4346
To enable thinking, use the [`AnthropicModelSettings.anthropic_thinking`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_thinking] [model setting](agents.md#model-run-settings).

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 105 additions & 76 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import base64
44
import hashlib
55
from abc import ABC, abstractmethod
6-
from collections.abc import Sequence
6+
from collections.abc import Callable, Sequence
77
from dataclasses import KW_ONLY, dataclass, field, replace
88
from datetime import datetime
99
from mimetypes import guess_type
@@ -64,6 +64,9 @@
6464
]
6565
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
6666

67+
ProviderDetailsDelta: TypeAlias = dict[str, Any] | Callable[[dict[str, Any] | None], dict[str, Any]] | None
68+
"""Type for provider_details input: can be a static dict, a callback to update existing details, or None."""
69+
6770

6871
@dataclass(repr=False)
6972
class SystemPromptPart:
@@ -1525,9 +1528,12 @@ class ThinkingPartDelta:
15251528
Signatures are only sent back to the same provider.
15261529
"""
15271530

1528-
provider_details: dict[str, Any] | None = None
1531+
provider_details: ProviderDetailsDelta = None
15291532
"""Additional data returned by the provider that can't be mapped to standard fields.
15301533
1534+
Can be a dict to merge with existing details, or a callable that takes
1535+
the existing details and returns updated details.
1536+
15311537
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""
15321538

15331539
part_delta_kind: Literal['thinking'] = 'thinking'
@@ -1555,7 +1561,13 @@ def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | T
15551561
new_content = part.content + self.content_delta if self.content_delta else part.content
15561562
new_signature = self.signature_delta if self.signature_delta is not None else part.signature
15571563
new_provider_name = self.provider_name if self.provider_name is not None else part.provider_name
1558-
new_provider_details = {**(part.provider_details or {}), **(self.provider_details or {})} or None
1564+
# Resolve callable provider_details if needed
1565+
resolved_details = (
1566+
self.provider_details(part.provider_details)
1567+
if callable(self.provider_details)
1568+
else self.provider_details
1569+
)
1570+
new_provider_details = {**(part.provider_details or {}), **(resolved_details or {})} or None
15591571
return replace(
15601572
part,
15611573
content=new_content,
@@ -1573,7 +1585,28 @@ def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | T
15731585
if self.provider_name is not None:
15741586
part = replace(part, provider_name=self.provider_name)
15751587
if self.provider_details is not None:
1576-
part = replace(part, provider_details={**(part.provider_details or {}), **self.provider_details})
1588+
if callable(self.provider_details):
1589+
if callable(part.provider_details):
1590+
existing_fn = part.provider_details
1591+
new_fn = self.provider_details
1592+
1593+
def chained_both(d: dict[str, Any] | None) -> dict[str, Any]:
1594+
return new_fn(existing_fn(d))
1595+
1596+
part = replace(part, provider_details=chained_both)
1597+
else:
1598+
part = replace(part, provider_details=self.provider_details)
1599+
elif callable(part.provider_details):
1600+
existing_fn = part.provider_details
1601+
new_dict = self.provider_details
1602+
1603+
def chained_dict(d: dict[str, Any] | None) -> dict[str, Any]:
1604+
return {**existing_fn(d), **new_dict}
1605+
1606+
part = replace(part, provider_details=chained_dict)
1607+
else:
1608+
existing = part.provider_details if isinstance(part.provider_details, dict) else {}
1609+
part = replace(part, provider_details={**existing, **self.provider_details})
15771610
return part
15781611
raise ValueError( # pragma: no cover
15791612
f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})'

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,25 +1129,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
11291129
elif isinstance(event, BetaRawContentBlockStartEvent):
11301130
current_block = event.content_block
11311131
if isinstance(current_block, BetaTextBlock) and current_block.text:
1132-
maybe_event = self._parts_manager.handle_text_delta(
1132+
for event_ in self._parts_manager.handle_text_delta(
11331133
vendor_part_id=event.index, content=current_block.text
1134-
)
1135-
if maybe_event is not None: # pragma: no branch
1136-
yield maybe_event
1134+
):
1135+
yield event_
11371136
elif isinstance(current_block, BetaThinkingBlock):
1138-
yield self._parts_manager.handle_thinking_delta(
1137+
for event_ in self._parts_manager.handle_thinking_delta(
11391138
vendor_part_id=event.index,
11401139
content=current_block.thinking,
11411140
signature=current_block.signature,
11421141
provider_name=self.provider_name,
1143-
)
1142+
):
1143+
yield event_
11441144
elif isinstance(current_block, BetaRedactedThinkingBlock):
1145-
yield self._parts_manager.handle_thinking_delta(
1145+
for event_ in self._parts_manager.handle_thinking_delta(
11461146
vendor_part_id=event.index,
11471147
id='redacted_thinking',
11481148
signature=current_block.data,
11491149
provider_name=self.provider_name,
1150-
)
1150+
):
1151+
yield event_
11511152
elif isinstance(current_block, BetaToolUseBlock):
11521153
maybe_event = self._parts_manager.handle_tool_call_delta(
11531154
vendor_part_id=event.index,
@@ -1208,23 +1209,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
12081209

12091210
elif isinstance(event, BetaRawContentBlockDeltaEvent):
12101211
if isinstance(event.delta, BetaTextDelta):
1211-
maybe_event = self._parts_manager.handle_text_delta(
1212+
for event_ in self._parts_manager.handle_text_delta(
12121213
vendor_part_id=event.index, content=event.delta.text
1213-
)
1214-
if maybe_event is not None: # pragma: no branch
1215-
yield maybe_event
1214+
):
1215+
yield event_
12161216
elif isinstance(event.delta, BetaThinkingDelta):
1217-
yield self._parts_manager.handle_thinking_delta(
1217+
for event_ in self._parts_manager.handle_thinking_delta(
12181218
vendor_part_id=event.index,
12191219
content=event.delta.thinking,
12201220
provider_name=self.provider_name,
1221-
)
1221+
):
1222+
yield event_
12221223
elif isinstance(event.delta, BetaSignatureDelta):
1223-
yield self._parts_manager.handle_thinking_delta(
1224+
for event_ in self._parts_manager.handle_thinking_delta(
12241225
vendor_part_id=event.index,
12251226
signature=event.delta.signature,
12261227
provider_name=self.provider_name,
1227-
)
1228+
):
1229+
yield event_
12281230
elif isinstance(event.delta, BetaInputJSONDelta):
12291231
maybe_event = self._parts_manager.handle_tool_call_delta(
12301232
vendor_part_id=event.index,

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -751,24 +751,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
751751
delta = content_block_delta['delta']
752752
if 'reasoningContent' in delta:
753753
if redacted_content := delta['reasoningContent'].get('redactedContent'):
754-
yield self._parts_manager.handle_thinking_delta(
754+
for event in self._parts_manager.handle_thinking_delta(
755755
vendor_part_id=index,
756756
id='redacted_content',
757757
signature=redacted_content.decode('utf-8'),
758758
provider_name=self.provider_name,
759-
)
759+
):
760+
yield event
760761
else:
761762
signature = delta['reasoningContent'].get('signature')
762-
yield self._parts_manager.handle_thinking_delta(
763+
for event in self._parts_manager.handle_thinking_delta(
763764
vendor_part_id=index,
764765
content=delta['reasoningContent'].get('text'),
765766
signature=signature,
766767
provider_name=self.provider_name if signature else None,
767-
)
768+
):
769+
yield event
768770
if text := delta.get('text'):
769-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
770-
if maybe_event is not None: # pragma: no branch
771-
yield maybe_event
771+
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
772+
yield event
772773
if 'toolUse' in delta:
773774
tool_use = delta['toolUse']
774775
maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,26 +292,26 @@ class FunctionStreamedResponse(StreamedResponse):
292292
def __post_init__(self):
293293
self._usage += _estimate_usage([])
294294

295-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
295+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
296296
async for item in self._iter:
297297
if isinstance(item, str):
298298
response_tokens = _estimate_string_tokens(item)
299299
self._usage += usage.RequestUsage(output_tokens=response_tokens)
300-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
301-
if maybe_event is not None: # pragma: no branch
302-
yield maybe_event
300+
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
301+
yield event
303302
elif isinstance(item, dict) and item:
304303
for dtc_index, delta in item.items():
305304
if isinstance(delta, DeltaThinkingPart):
306305
if delta.content: # pragma: no branch
307306
response_tokens = _estimate_string_tokens(delta.content)
308307
self._usage += usage.RequestUsage(output_tokens=response_tokens)
309-
yield self._parts_manager.handle_thinking_delta(
308+
for event in self._parts_manager.handle_thinking_delta(
310309
vendor_part_id=dtc_index,
311310
content=delta.content,
312311
signature=delta.signature,
313312
provider_name='function' if delta.signature else None,
314-
)
313+
):
314+
yield event
315315
elif isinstance(delta, DeltaToolCall):
316316
if delta.json_args:
317317
response_tokens = _estimate_string_tokens(delta.json_args)

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,11 +465,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
465465
if 'text' in gemini_part:
466466
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
467467
# amongst the tool call deltas
468-
maybe_event = self._parts_manager.handle_text_delta(
468+
for event in self._parts_manager.handle_text_delta(
469469
vendor_part_id=None, content=gemini_part['text']
470-
)
471-
if maybe_event is not None: # pragma: no branch
472-
yield maybe_event
470+
):
471+
yield event
473472

474473
elif 'function_call' in gemini_part:
475474
# Here, we assume all function_call parts are complete and don't have deltas.

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,15 +722,15 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
722722
if len(part.text) == 0 and not provider_details:
723723
continue
724724
if part.thought:
725-
yield self._parts_manager.handle_thinking_delta(
725+
for event in self._parts_manager.handle_thinking_delta(
726726
vendor_part_id=None, content=part.text, provider_details=provider_details
727-
)
727+
):
728+
yield event
728729
else:
729-
maybe_event = self._parts_manager.handle_text_delta(
730+
for event in self._parts_manager.handle_text_delta(
730731
vendor_part_id=None, content=part.text, provider_details=provider_details
731-
)
732-
if maybe_event is not None: # pragma: no branch
733-
yield maybe_event
732+
):
733+
yield event
734734
elif part.function_call:
735735
maybe_event = self._parts_manager.handle_tool_call_delta(
736736
vendor_part_id=uuid4(),

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
551551
reasoning = True
552552

553553
# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
554-
yield self._parts_manager.handle_thinking_delta(
554+
for event in self._parts_manager.handle_thinking_delta(
555555
vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
556-
)
556+
):
557+
yield event
557558
else:
558559
reasoning = False
559560

@@ -576,14 +577,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
576577
# Handle the text part of the response
577578
content = choice.delta.content
578579
if content:
579-
maybe_event = self._parts_manager.handle_text_delta(
580+
for event in self._parts_manager.handle_text_delta(
580581
vendor_part_id='content',
581582
content=content,
582583
thinking_tags=self._model_profile.thinking_tags,
583584
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
584-
)
585-
if maybe_event is not None: # pragma: no branch
586-
yield maybe_event
585+
):
586+
yield event
587587

588588
# Handle the tool calls
589589
for dtc in choice.delta.tool_calls or []:

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
487487
# Handle the text part of the response
488488
content = choice.delta.content
489489
if content:
490-
maybe_event = self._parts_manager.handle_text_delta(
490+
for event in self._parts_manager.handle_text_delta(
491491
vendor_part_id='content',
492492
content=content,
493493
thinking_tags=self._model_profile.thinking_tags,
494494
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
495-
)
496-
if maybe_event is not None: # pragma: no branch
497-
yield maybe_event
495+
):
496+
yield event
498497

499498
for dtc in choice.delta.tool_calls or []:
500499
maybe_event = self._parts_manager.handle_tool_call_delta(

0 commit comments

Comments
 (0)