Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ class CohereDocumentEmbedder:
Usage example:
```python
from haystack import Document
from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder
from haystack_integrations.components.embedders.cohere import (
CohereDocumentEmbedder,
)

doc = Document(content="I love pizza!")

document_embedder = CohereDocumentEmbedder()

result = document_embedder.run([doc])
print(result['documents'][0].embedding)
print(result["documents"][0].embedding)

# [-0.453125, 1.2236328, 2.0058594, ...]
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class CohereDocumentImageEmbedder:
### Usage example
```python
from haystack import Document
from haystack_integrations.components.embedders.cohere import CohereDocumentImageEmbedder
from haystack_integrations.components.embedders.cohere import (
CohereDocumentImageEmbedder,
)

embedder = CohereDocumentImageEmbedder(model="embed-v4.0")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall
from haystack.dataclasses.streaming_chunk import (
AsyncStreamingCallbackT,
FinishReason,
Expand Down Expand Up @@ -184,6 +184,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
Extracts and organizes various response components including:
- Text content
- Tool calls
- Reasoning content (via native Cohere API)
- Usage statistics
- Citations
- Metadata
Expand All @@ -192,6 +193,21 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
:param model: The name of the model that generated the response.
:return: A Haystack ChatMessage containing the formatted response.
"""
# Extract reasoning content using Cohere's native API
reasoning_content = None
text_content = ""

if chat_response.message.content:
for content_item in chat_response.message.content:
# Access thinking content via native Cohere API
if hasattr(content_item, "type") and content_item.type == "thinking":
if hasattr(content_item, "thinking") and content_item.thinking:
reasoning_content = ReasoningContent(reasoning_text=content_item.thinking)
# Access text content
elif hasattr(content_item, "type") and content_item.type == "text":
if hasattr(content_item, "text") and content_item.text:
text_content = content_item.text

if chat_response.message.tool_calls:
tool_calls = []
for tc in chat_response.message.tool_calls:
Expand All @@ -206,9 +222,9 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:

# Create message with tool plan as text and tool calls in the format Haystack expects
tool_plan = chat_response.message.tool_plan or ""
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
message = ChatMessage.from_assistant(chat_response.message.content[0].text)
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content)
elif text_content:
message = ChatMessage.from_assistant(text_content, reasoning=reasoning_content)
else:
# Handle the case where neither tool_calls nor content exists
logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
Expand Down Expand Up @@ -437,9 +453,13 @@ class CohereChatGenerator:
```python
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack_integrations.components.generators.cohere import CohereChatGenerator
from haystack_integrations.components.generators.cohere import (
CohereChatGenerator,
)

client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY"))
client = CohereChatGenerator(
model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY")
)
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
client.run(messages)

Expand All @@ -451,16 +471,25 @@ class CohereChatGenerator:
```python
from haystack.dataclasses import ChatMessage, ImageContent
from haystack.utils import Secret
from haystack_integrations.components.generators.cohere import CohereChatGenerator
from haystack_integrations.components.generators.cohere import (
CohereChatGenerator,
)

# Create an image from file path or base64
image_content = ImageContent.from_file_path("path/to/your/image.jpg")

# Create a multimodal message with both text and image
messages = [ChatMessage.from_user(content_parts=["What's in this image?", image_content])]
messages = [
ChatMessage.from_user(
content_parts=["What's in this image?", image_content]
)
]

# Use a multimodal model like Command A Vision
client = CohereChatGenerator(model="command-a-vision-07-2025", api_key=Secret.from_env_var("COHERE_API_KEY"))
client = CohereChatGenerator(
model="command-a-vision-07-2025",
api_key=Secret.from_env_var("COHERE_API_KEY"),
)
response = client.run(messages)
print(response)
```
Expand All @@ -475,12 +504,16 @@ class CohereChatGenerator:
from haystack.dataclasses import ChatMessage
from haystack.components.tools import ToolInvoker
from haystack.tools import Tool
from haystack_integrations.components.generators.cohere import CohereChatGenerator
from haystack_integrations.components.generators.cohere import (
CohereChatGenerator,
)


# Create a weather tool
def weather(city: str) -> str:
return f"The weather in {city} is sunny and 32°C"


weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
Expand All @@ -499,13 +532,22 @@ def weather(city: str) -> str:

# Create and set up the pipeline
pipeline = Pipeline()
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
pipeline.add_component(
"generator",
CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]),
)
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
pipeline.connect("generator", "tool_invoker")

# Run the pipeline with a weather query
results = pipeline.run(
data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}}
data={
"generator": {
"messages": [
ChatMessage.from_user("What's the weather like in Paris?")
]
}
}
)

# The tool result will be available in the pipeline output
Expand Down
168 changes: 161 additions & 7 deletions integrations/cohere/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
from haystack import Pipeline
from haystack.components.generators.utils import print_streaming_chunk
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ToolCall
from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ReasoningContent, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, Toolset
from haystack.utils import Secret

from haystack_integrations.components.generators.cohere import CohereChatGenerator
from haystack_integrations.components.generators.cohere.chat.chat_generator import (
_format_message,
)
from haystack_integrations.components.generators.cohere.chat.chat_generator import _format_message


def weather(city: str) -> str:
Expand Down Expand Up @@ -444,11 +442,14 @@ def test_run_image(self):

generator = CohereChatGenerator(api_key=Secret.from_token("test-api-key"))

# Mock the client's chat method
# Mock the client's chat method with proper content structure
mock_response = MagicMock()
mock_response.message.content = [MagicMock()]
mock_response.message.content[0].text = "This is a test image response"
text_content = MagicMock()
text_content.type = "text"
text_content.text = "This is a test image response"
mock_response.message.content = [text_content]
mock_response.message.tool_calls = None
mock_response.message.citations = None
mock_response.finish_reason = "COMPLETE"
mock_response.usage = None

Expand Down Expand Up @@ -727,6 +728,159 @@ def test_live_run_multimodal(self):
assert isinstance(results["replies"][0], ChatMessage)
assert len(results["replies"][0].text) > 0


class TestCohereChatGeneratorReasoning:
"""Integration tests for reasoning functionality in CohereChatGenerator."""

@pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set")
@pytest.mark.integration
def test_reasoning_with_command_a_reasoning_model(self):
    """Live check that a Command A Reasoning model surfaces reasoning content."""
    chat_generator = CohereChatGenerator(
        model="command-a-reasoning-111b-2024-10-03",
        generation_kwargs={"thinking": True},  # turn on native reasoning
    )

    prompt = "Solve this math problem step by step: What is the area of a circle with radius 7?"
    output = chat_generator.run(messages=[ChatMessage.from_user(prompt)])

    assert "replies" in output
    assert len(output["replies"]) == 1

    message = output["replies"][0]
    assert isinstance(message, ChatMessage)
    assert message.role == ChatRole.ASSISTANT

    # Reasoning is optional in live responses; validate only when present.
    if message.reasoning:
        assert isinstance(message.reasoning, ReasoningContent)
        # Expect a non-trivial chain of thought, not a stub.
        assert len(message.reasoning.reasoning_text) > 50

        lowered_reasoning = message.reasoning.reasoning_text.lower()
        assert any(term in lowered_reasoning for term in ["area", "circle", "radius", "formula", "π", "pi"])

    # The visible answer should reference the computed result.
    assert len(message.text) > 0
    lowered_answer = message.text.lower()
    assert any(term in lowered_answer for term in ["area", "153.94", "154", "square"])

def test_reasoning_with_mock_response(self):
    """Verify reasoning extraction from a mocked native-API Cohere response."""
    chat_generator = CohereChatGenerator(
        model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key")
    )

    # Build a fake response whose content list mixes a "thinking" item
    # and a "text" item, mirroring Cohere's native content structure.
    thinking_item = MagicMock()
    thinking_item.type = "thinking"
    thinking_item.thinking = """I need to solve for the area of a circle.
The formula is A = πr²
With radius 7: A = π * 7² = π * 49 ≈ 153.94"""

    text_item = MagicMock()
    text_item.type = "text"
    text_item.text = "The area of a circle with radius 7 is approximately 153.94 square units."

    fake_response = MagicMock()
    fake_response.message.content = [thinking_item, text_item]
    fake_response.message.tool_calls = None
    fake_response.message.citations = None
    fake_response.finish_reason = "COMPLETE"
    fake_response.usage = None

    chat_generator.client.chat = MagicMock(return_value=fake_response)

    output = chat_generator.run(
        messages=[ChatMessage.from_user("What is the area of a circle with radius 7?")]
    )

    assert "replies" in output
    assert len(output["replies"]) == 1

    message = output["replies"][0]
    assert isinstance(message, ChatMessage)
    assert message.role == ChatRole.ASSISTANT

    # The thinking item must surface as ReasoningContent on the reply.
    assert message.reasoning is not None
    assert isinstance(message.reasoning, ReasoningContent)
    assert "formula is A = πr²" in message.reasoning.reasoning_text
    assert "π * 49 ≈ 153.94" in message.reasoning.reasoning_text

    # The text item must surface as the reply's main text.
    assert message.text.strip() == "The area of a circle with radius 7 is approximately 153.94 square units."

def test_reasoning_with_tool_calls_compatibility(self):
    """Verify reasoning content coexists with tool calls in a mocked response."""
    weather_tool = Tool(
        name="weather",
        description="Get weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        function=weather,
    )

    chat_generator = CohereChatGenerator(
        model="command-a-reasoning-111b-2024-10-03", tools=[weather_tool], api_key=Secret.from_token("fake-api-key")
    )

    # Fake a native-API response carrying both a thinking item and a tool call.
    thinking_item = MagicMock()
    thinking_item.type = "thinking"
    thinking_item.thinking = (
        "The user is asking about weather in Paris. I should use the weather tool to get accurate information."
    )

    fake_tool_call = MagicMock()
    fake_tool_call.function.name = "weather"
    fake_tool_call.function.arguments = '{"city": "Paris"}'
    fake_tool_call.id = "call_123"

    fake_response = MagicMock()
    fake_response.message.content = [thinking_item]
    fake_response.message.tool_calls = [fake_tool_call]
    fake_response.message.tool_plan = "I'll check the weather in Paris for you."
    fake_response.message.citations = None
    fake_response.finish_reason = "TOOL_CALLS"
    fake_response.usage = None

    chat_generator.client.chat = MagicMock(return_value=fake_response)

    output = chat_generator.run(messages=[ChatMessage.from_user("What's the weather like in Paris?")])

    assert "replies" in output
    assert len(output["replies"]) == 1

    message = output["replies"][0]
    assert isinstance(message, ChatMessage)

    # Reasoning extracted from the thinking item must be attached.
    assert message.reasoning is not None
    assert isinstance(message.reasoning, ReasoningContent)
    assert "weather tool" in message.reasoning.reasoning_text

    # The tool call must survive the conversion.
    assert message.tool_calls is not None
    assert len(message.tool_calls) == 1
    assert message.tool_calls[0].tool_name == "weather"

    # The tool plan becomes the reply's visible text.
    assert "I'll check the weather in Paris" in message.text

@pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set")
@pytest.mark.integration
def test_live_run_with_mixed_tools(self):
"""
Integration test that verifies CohereChatGenerator works with mixed Tool and Toolset.
Expand Down