4 changes: 2 additions & 2 deletions rust/cocoindex/src/llm/anthropic.rs
@@ -123,7 +123,7 @@ impl LlmGenerationClient for Client {
         }
         let text = if let Some(json) = extracted_json {
             // Try strict JSON serialization first
-            serde_json::to_string(&json)?
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fallback: try text if no tool output found
             match &mut resp_json["content"][0]["text"] {
@@ -155,7 +155,7 @@
             }
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
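Worth noting in the hunk above: the new `return Ok(...)` sits inside a branch whose value is bound to `text`. This type-checks because a diverging branch has the never type, which unifies with the `String` produced by the `else` arm. A minimal standalone sketch of the idiom (the types here are illustrative, not the crate's):

// A `return` inside one arm of `let ... = if ... else ...` diverges (type `!`),
// so only the other arm has to produce the bound value.
fn pick(structured: Option<i32>) -> Result<String, String> {
    let text = if let Some(n) = structured {
        // This arm never assigns to `text`; it exits the function early.
        return Ok(format!("json: {n}"));
    } else {
        "fallback text".to_string()
    };
    Ok(text)
}

fn main() {
    assert_eq!(pick(Some(7)), Ok("json: 7".to_string()));
    assert_eq!(pick(None), Ok("fallback text".to_string()));
}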
4 changes: 2 additions & 2 deletions rust/cocoindex/src/llm/bedrock.rs
@@ -148,7 +148,7 @@ impl LlmGenerationClient for Client {

         if let Some(json) = extracted_json {
             // Return the structured output as JSON
-            serde_json::to_string(&json)?
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fall back to text content
             let mut text_parts = Vec::new();
@@ -165,7 +165,7 @@
             return Err(anyhow::anyhow!("No content found in Bedrock response"));
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
32 changes: 28 additions & 4 deletions rust/cocoindex/src/llm/gemini.rs
@@ -147,8 +147,11 @@ impl LlmGenerationClient for AiStudioClient {
             });
         }

+        let mut need_json = false;
+
         // If structured output is requested, add schema and responseMimeType
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let mut schema_json = serde_json::to_value(schema)?;
             remove_additional_properties(&mut schema_json);
             payload["generationConfig"] = serde_json::json!({
@@ -161,18 +164,24 @@
         let resp = http::request(|| self.client.post(&url).json(&payload))
             .await
             .context("Gemini API error")?;
-        let resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;

         if let Some(error) = resp_json.get("error") {
             bail!("Gemini API error: {:?}", error);
         }
-        let mut resp_json = resp_json;
+
+        if need_json {
+            return Ok(super::LlmGenerateResponse::Json(serde_json::json!(
+                resp_json["candidates"][0]
+            )));
+        }

         let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
             Value::String(s) => std::mem::take(s),
             _ => bail!("No text in response"),
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -333,9 +342,12 @@ impl LlmGenerationClient for VertexAiClient {
                 .set_parts(vec![Part::new().set_text(sys.to_string())])
         });

+        let mut need_json = false;
+
         // Compose generation config
         let mut generation_config = None;
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let schema_json = serde_json::to_value(schema)?;
             generation_config = Some(
                 GenerationConfig::new()
@@ -359,6 +371,18 @@

         // Call the API
         let resp = req.send().await?;
+
+        if need_json {
+            match resp.candidates.into_iter().next() {
+                Some(resp_json) => {
+                    return Ok(super::LlmGenerateResponse::Json(serde_json::json!(
+                        resp_json
+                    )));
+                }
+                None => bail!("No response"),
+            }
+        }
+
         // Extract text from response
         let Some(Data::Text(text)) = resp
             .candidates
@@ -370,7 +394,7 @@
         else {
             bail!("No text in response");
         };
-        Ok(super::LlmGenerateResponse { text })
+        Ok(super::LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
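Both Gemini clients keep the `std::mem::take` idiom on the text path, which is why `resp_json` is now bound with `mut` up front instead of being rebound later: the `String` is moved out of the response `Value` without cloning it. A self-contained sketch (uses `serde_json`; the response shape is abbreviated):

use serde_json::{Value, json};

fn main() {
    // Abbreviated response shape, just enough to show the idiom.
    let mut resp_json = json!({
        "candidates": [{ "content": { "parts": [{ "text": "hello" }] } }]
    });
    let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
        // Moves the String out, leaving "" behind instead of cloning the payload.
        Value::String(s) => std::mem::take(s),
        _ => panic!("no text in response"),
    };
    assert_eq!(text, "hello");
    assert_eq!(resp_json["candidates"][0]["content"]["parts"][0]["text"], "");
}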
5 changes: 3 additions & 2 deletions rust/cocoindex/src/llm/mod.rs
@@ -66,8 +66,9 @@ pub struct LlmGenerateRequest<'a> {
 }

 #[derive(Debug)]
-pub struct LlmGenerateResponse {
-    pub text: String,
+pub enum LlmGenerateResponse {
+    Text(String),
+    Json(serde_json::Value),
 }

 #[async_trait]
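Since the response is now an enum rather than a struct with a `text` field, every caller has to match on the variant. A sketch of the caller-side contract, with the enum mirrored locally for illustration (the real definition is the one in this hunk; `anyhow` and `serde_json` are assumed, as elsewhere in the crate):

use anyhow::Result;

#[derive(Debug)]
pub enum LlmGenerateResponse {
    Text(String),
    Json(serde_json::Value),
}

// Normalize either variant into a serde_json::Value, parsing the Text
// variant on the assumption that it carries JSON (as extract_by_llm does).
fn response_as_value(resp: LlmGenerateResponse) -> Result<serde_json::Value> {
    match resp {
        LlmGenerateResponse::Text(text) => Ok(serde_json::from_str(&text)?),
        LlmGenerateResponse::Json(value) => Ok(value),
    }
}

fn main() -> Result<()> {
    let v = response_as_value(LlmGenerateResponse::Text(r#"{"ok": true}"#.into()))?;
    assert_eq!(v["ok"], true);
    Ok(())
}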
6 changes: 2 additions & 4 deletions rust/cocoindex/src/llm/ollama.rs
@@ -108,10 +108,8 @@ impl LlmGenerationClient for Client {
             })
             .await
             .context("Ollama API error")?;
-        let json: OllamaResponse = res.json().await?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+
+        Ok(super::LlmGenerateResponse::Json(res.json().await?))
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
32 changes: 23 additions & 9 deletions rust/cocoindex/src/llm/openai.rs
@@ -1,4 +1,4 @@
-use crate::prelude::*;
+use crate::{llm::OutputFormat, prelude::*};
 use base64::prelude::*;

 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
@@ -145,15 +145,29 @@ impl LlmGenerationClient for Client {
         )
         .await?;

-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+        let mut response_iter = response.choices.into_iter();

-        Ok(super::LlmGenerateResponse { text })
+        match request.output_format {
+            Some(OutputFormat::JsonSchema { .. }) => {
+                // Extract the response json from the first choice
+                let response_json = serde_json::json!(
+                    response_iter
+                        .next()
+                        .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?
+                );
+
+                Ok(super::LlmGenerateResponse::Json(response_json))
+            }
+            None => {
+                // Extract the response text from the first choice
+                let text = response_iter
+                    .next()
+                    .and_then(|choice| choice.message.content)
+                    .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+                Ok(super::LlmGenerateResponse::Text(text))
+            }
+        }
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
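The hunk above keys the variant off `request.output_format`: a `JsonSchema` request serializes the whole first choice into a `Value`, while a plain request still unwraps `choice.message.content` into text. A simplified standalone sketch of that dispatch (both enums are mirrored and stripped down here; the crate's `OutputFormat::JsonSchema` carries a schema, and the choice is an API struct rather than a raw `Value`):

enum OutputFormat { JsonSchema }
enum LlmGenerateResponse { Text(String), Json(serde_json::Value) }

fn into_response(
    output_format: Option<OutputFormat>,
    first_choice: Option<serde_json::Value>,
    text_content: Option<String>,
) -> anyhow::Result<LlmGenerateResponse> {
    match output_format {
        // Structured output requested: hand the first choice back as JSON.
        Some(OutputFormat::JsonSchema) => first_choice
            .map(LlmGenerateResponse::Json)
            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")),
        // Plain generation: fall back to the message text.
        None => text_content
            .map(LlmGenerateResponse::Text)
            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")),
    }
}

fn main() -> anyhow::Result<()> {
    let resp = into_response(None, None, Some("hi".to_string()))?;
    assert!(matches!(resp, LlmGenerateResponse::Text(_)));
    Ok(())
}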
5 changes: 4 additions & 1 deletion rust/cocoindex/src/ops/functions/extract_by_llm.rs
@@ -113,7 +113,10 @@ impl SimpleFunctionExecutor for Executor {
             }),
         };
         let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = utils::deser::from_json_str(res.text.as_str())?;
+        let json_value = match res {
+            crate::llm::LlmGenerateResponse::Text(text) => utils::deser::from_json_str(&text)?,
+            crate::llm::LlmGenerateResponse::Json(value) => value,
+        };
         let value = self.value_extractor.extract_value(json_value)?;
         Ok(value)
     }
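The net effect of the match above: providers with native structured output now skip a serialize-then-reparse round trip, while text responses still go through JSON parsing before value extraction. A quick standalone check that both paths land on the same `Value` (plain `serde_json::from_str` stands in for the crate's `utils::deser::from_json_str` helper):

use serde_json::{Value, json};

fn main() -> anyhow::Result<()> {
    // Text path: the model returned JSON as a string, so parse it.
    let from_text: Value = serde_json::from_str(r#"{"title": "cocoindex"}"#)?;
    // Json path: the provider already returned structured output.
    let from_json: Value = json!({"title": "cocoindex"});
    assert_eq!(from_text, from_json);
    Ok(())
}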