Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import re
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union, get_args

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall
from haystack.dataclasses.streaming_chunk import (
AsyncStreamingCallbackT,
FinishReason,
Expand Down Expand Up @@ -184,6 +185,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
Extracts and organizes various response components including:
- Text content
- Tool calls
- Reasoning content (via native Cohere API with text-based fallback)
- Usage statistics
- Citations
- Metadata
Expand All @@ -192,6 +194,29 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
:param model: The name of the model that generated the response.
:return: A Haystack ChatMessage containing the formatted response.
"""
# Try to extract reasoning content using Cohere's native API (preferred method)
reasoning_content = None
text_content = ""

if chat_response.message.content:
for content_item in chat_response.message.content:
# Access thinking content via native Cohere API
if hasattr(content_item, "type") and content_item.type == "thinking":
if hasattr(content_item, "thinking") and content_item.thinking:
reasoning_content = ReasoningContent(reasoning_text=content_item.thinking)
# Access text content
elif hasattr(content_item, "type") and content_item.type == "text":
if hasattr(content_item, "text") and content_item.text:
text_content = content_item.text

# Fallback: If reasoning wasn't found via native API but text contains reasoning markers,
# extract it from text (for backward compatibility)
if reasoning_content is None and text_content:
fallback_reasoning, cleaned_text = _extract_reasoning_from_text(text_content)
if fallback_reasoning is not None:
reasoning_content = fallback_reasoning
text_content = cleaned_text

if chat_response.message.tool_calls:
tool_calls = []
for tc in chat_response.message.tool_calls:
Expand All @@ -206,9 +231,9 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:

# Create message with tool plan as text and tool calls in the format Haystack expects
tool_plan = chat_response.message.tool_plan or ""
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
message = ChatMessage.from_assistant(chat_response.message.content[0].text)
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content)
elif text_content:
message = ChatMessage.from_assistant(text_content, reasoning=reasoning_content)
else:
# Handle the case where neither tool_calls nor content exists
logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
Expand Down Expand Up @@ -352,6 +377,103 @@ def _convert_cohere_chunk_to_streaming_chunk(
)


def _extract_reasoning_from_text(response_text: str) -> tuple[Optional[ReasoningContent], str]:
"""
Extract reasoning content from text as a fallback method.

This is used when reasoning is not available via native Cohere API
(e.g., in streaming mode or for backward compatibility).

:param response_text: The raw response text from Cohere
:returns: A tuple of (ReasoningContent or None, cleaned_response_text)
"""
if not response_text or not isinstance(response_text, str):
return None, response_text

# Pattern 1: Look for thinking/reasoning tags
thinking_patterns = [
r"<thinking>(.*?)</thinking>",
r"<reasoning>(.*?)</reasoning>",
r"## Reasoning\s*\n(.*?)(?=\n## |$)",
r"## Thinking\s*\n(.*?)(?=\n## |$)",
]

for pattern in thinking_patterns:
match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE)
if match:
reasoning_text = match.group(1).strip()
cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
min_reasoning_length = 30
if len(reasoning_text) > min_reasoning_length:
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content
else:
return None, cleaned_content

# Pattern 2: Look for step-by-step reasoning at start
lines = response_text.split("\n")
max_lines_to_check = 10
for i, line in enumerate(lines):
stripped_line = line.strip()
if (
stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve"))
or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning"))
or (
len(stripped_line) > 0
and stripped_line.endswith(":")
and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower())
)
):
reasoning_end = len(lines)
for j in range(i + 1, len(lines)):
next_line = lines[j].strip()
if next_line.startswith(
("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer")
):
reasoning_end = j
break

reasoning_lines = lines[:reasoning_end]
content_lines = lines[reasoning_end:]
reasoning_text = "\n".join(reasoning_lines).strip()
cleaned_content = "\n".join(content_lines).strip()
min_reasoning_length = 30
if len(reasoning_text) > min_reasoning_length:
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content
break

if i > max_lines_to_check: # Stop looking after first few lines
break

return None, response_text


def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage:
    """
    Convert streaming chunks to ChatMessage with reasoning extraction support.

    For streaming, reasoning might not come via the native API, so we fall back to
    text-based extraction on the assembled message. When no reasoning is detected
    the assembled message is returned unchanged.
    """
    message = _convert_streaming_chunks_to_chat_message(chunks=chunks)

    if message.text:
        reasoning, remaining_text = _extract_reasoning_from_text(message.text)
        if reasoning is not None:
            # Rebuild the message so reasoning is carried separately from the text,
            # preserving tool calls and metadata from the assembled message.
            return ChatMessage.from_assistant(
                text=remaining_text,
                reasoning=reasoning,
                tool_calls=message.tool_calls,
                meta=message.meta,
            )

    return message


def _parse_streaming_response(
response: Iterator[StreamedChatResponseV2],
model: str,
Expand Down Expand Up @@ -383,7 +505,7 @@ def _parse_streaming_response(
chunks.append(streaming_chunk)
streaming_callback(streaming_chunk)

return _convert_streaming_chunks_to_chat_message(chunks=chunks)
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)


async def _parse_async_streaming_response(
Expand Down Expand Up @@ -411,7 +533,7 @@ async def _parse_async_streaming_response(
chunks.append(streaming_chunk)
await streaming_callback(streaming_chunk)

return _convert_streaming_chunks_to_chat_message(chunks=chunks)
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)


@component
Expand Down
Loading