Skip to content

Commit 1395396

Browse files
committed
Allow to pass extra info to table functions
Via `register_table_function_with_extra_info`
1 parent 952943a commit 1395396

File tree

2 files changed

+113
-16
lines changed

2 files changed

+113
-16
lines changed

crates/duckdb/src/vtab/function.rs

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ use std::{
1818
os::raw::c_char,
1919
};
2020

21+
/// Callback function to drop extra info of type T
22+
unsafe extern "C" fn drop_extra_info<T>(ptr: *mut c_void) {
23+
drop(unsafe { Box::from_raw(ptr.cast::<T>()) });
24+
}
25+
2126
/// An interface to store and retrieve data during the function bind stage
2227
#[derive(Debug)]
2328
pub struct BindInfo {
@@ -36,6 +41,7 @@ impl BindInfo {
3641
duckdb_bind_add_result_column(self.ptr, c_str.as_ptr() as *const c_char, column_type.ptr);
3742
}
3843
}
44+
3945
/// Report that an error has occurred while calling bind.
4046
///
4147
/// # Arguments
@@ -51,16 +57,15 @@ impl BindInfo {
5157
/// # Arguments
5258
/// * `extra_data`: The bind data object.
5359
/// * `destroy`: The callback that will be called to destroy the bind data (if any)
54-
///
55-
/// # Safety
56-
///
5760
pub unsafe fn set_bind_data(&self, data: *mut c_void, free_function: Option<unsafe extern "C" fn(*mut c_void)>) {
5861
duckdb_bind_set_bind_data(self.ptr, data, free_function);
5962
}
63+
6064
/// Retrieves the number of regular (non-named) parameters to the function.
6165
pub fn get_parameter_count(&self) -> u64 {
6266
unsafe { duckdb_bind_get_parameter_count(self.ptr) }
6367
}
68+
6469
/// Retrieves the parameter at the given index.
6570
///
6671
/// # Arguments
@@ -107,10 +112,7 @@ impl BindInfo {
107112
pub fn set_cardinality(&self, cardinality: idx_t, is_exact: bool) {
108113
unsafe { duckdb_bind_set_cardinality(self.ptr, cardinality, is_exact) }
109114
}
110-
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`]
111-
///
112-
/// # Arguments
113-
/// * `returns`: The extra info
115+
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
114116
pub fn get_extra_info<T>(&self) -> *const T {
115117
unsafe { duckdb_bind_get_extra_info(self.ptr).cast() }
116118
}
@@ -162,13 +164,11 @@ impl InitInfo {
162164
indices
163165
}
164166

165-
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`]
166-
///
167-
/// # Arguments
168-
/// * `returns`: The extra info
167+
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
169168
pub fn get_extra_info<T>(&self) -> *const T {
170169
unsafe { duckdb_init_get_extra_info(self.0).cast() }
171170
}
171+
172172
/// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind.
173173
///
174174
/// Note that the bind data should be considered as read-only.
@@ -179,13 +179,15 @@ impl InitInfo {
179179
pub fn get_bind_data<T>(&self) -> *const T {
180180
unsafe { duckdb_init_get_bind_data(self.0).cast() }
181181
}
182+
182183
/// Sets how many threads can process this table function in parallel (default: 1)
183184
///
184185
/// # Arguments
185186
/// * `max_threads`: The maximum amount of threads that can process this table function
186187
pub fn set_max_threads(&self, max_threads: idx_t) {
187188
unsafe { duckdb_init_set_max_threads(self.0, max_threads) }
188189
}
190+
189191
/// Report that an error has occurred while calling init.
190192
///
191193
/// # Arguments
@@ -307,15 +309,35 @@ impl TableFunction {
307309

308310
/// Assigns extra information to the table function that can be fetched during binding, etc.
309311
///
312+
/// For most use cases, prefer [`with_extra_info`](Self::with_extra_info) which handles memory management automatically.
313+
///
310314
/// # Arguments
311315
/// * `extra_info`: The extra information
312316
/// * `destroy`: The callback that will be called to destroy the bind data (if any)
313317
///
314318
/// # Safety
319+
/// The caller must ensure that `extra_info` is a valid pointer and that `destroy`
320+
/// properly cleans up the data when called.
315321
pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) {
322+
duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy);
323+
}
324+
325+
/// Assigns extra information to the table function that can be fetched during binding, init, and execution.
326+
///
327+
/// This is a safe wrapper around [`set_extra_info`](Self::set_extra_info) that handles memory management automatically.
328+
///
329+
/// # Arguments
330+
/// * `info`: The extra information to store
331+
pub fn with_extra_info<T>(&self, info: T) -> &Self
332+
where
333+
T: Send + Sync + 'static,
334+
{
316335
unsafe {
317-
duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy);
336+
let boxed = Box::new(info);
337+
let ptr = Box::into_raw(boxed) as *mut c_void;
338+
self.set_extra_info(ptr, Some(drop_extra_info::<T>));
318339
}
340+
self
319341
}
320342

321343
/// Sets the thread-local init function of the table function
@@ -383,13 +405,11 @@ impl<V: VTab> TableFunctionInfo<V> {
383405
}
384406
}
385407

386-
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`]
387-
///
388-
/// # Arguments
389-
/// * `returns`: The extra info
408+
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
390409
pub fn get_extra_info<T>(&self) -> *mut T {
391410
unsafe { duckdb_function_get_extra_info(self.ptr).cast() }
392411
}
412+
393413
/// Gets the thread-local init data set by [`InitInfo::set_init_data`] during the local_init.
394414
///
395415
/// # Arguments

crates/duckdb/src/vtab/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,32 @@ impl Connection {
150150
}
151151
self.db.borrow_mut().register_table_function(table_function)
152152
}
153+
154+
/// Register the given TableFunction with custom extra info.
155+
///
156+
/// This allows you to pass extra info that can be accessed during bind, init, and execution
157+
/// via `BindInfo::get_extra_info`, `InitInfo::get_extra_info`, or `TableFunctionInfo::get_extra_info`.
158+
#[inline]
159+
pub fn register_table_function_with_extra_info<T: VTab, E>(&self, name: &str, extra_info: &E) -> Result<()>
160+
where
161+
E: Clone + Send + Sync + 'static,
162+
{
163+
let table_function = TableFunction::default();
164+
table_function
165+
.set_name(name)
166+
.supports_pushdown(T::supports_pushdown())
167+
.set_bind(Some(bind::<T>))
168+
.set_init(Some(init::<T>))
169+
.set_function(Some(func::<T>))
170+
.with_extra_info(extra_info.clone());
171+
for ty in T::parameters().unwrap_or_default() {
172+
table_function.add_parameter(&ty);
173+
}
174+
for (name, ty) in T::named_parameters().unwrap_or_default() {
175+
table_function.add_named_parameter(&name, &ty);
176+
}
177+
self.db.borrow_mut().register_table_function(table_function)
178+
}
153179
}
154180

155181
impl InnerConnection {
@@ -287,6 +313,57 @@ mod test {
287313
Ok(())
288314
}
289315

316+
// Test table function with extra info
317+
struct PrefixVTab;
318+
319+
impl VTab for PrefixVTab {
320+
type InitData = HelloInitData;
321+
type BindData = HelloBindData;
322+
323+
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn Error>> {
324+
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
325+
let name = bind.get_parameter(0).to_string();
326+
Ok(HelloBindData { name })
327+
}
328+
329+
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn Error>> {
330+
Ok(HelloInitData {
331+
done: AtomicBool::new(false),
332+
})
333+
}
334+
335+
fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn Error>> {
336+
let init_data = func.get_init_data();
337+
let bind_data = func.get_bind_data();
338+
let prefix = unsafe { &*func.get_extra_info::<String>() };
339+
340+
if init_data.done.swap(true, Ordering::Relaxed) {
341+
output.set_len(0);
342+
} else {
343+
let vector = output.flat_vector(0);
344+
let result = CString::new(format!("{prefix} {}", bind_data.name))?;
345+
vector.insert(0, result);
346+
output.set_len(1);
347+
}
348+
Ok(())
349+
}
350+
351+
fn parameters() -> Option<Vec<LogicalTypeHandle>> {
352+
Some(vec![LogicalTypeHandle::from(LogicalTypeId::Varchar)])
353+
}
354+
}
355+
356+
#[test]
357+
fn test_table_function_with_extra_info() -> Result<(), Box<dyn Error>> {
358+
let conn = Connection::open_in_memory()?;
359+
conn.register_table_function_with_extra_info::<PrefixVTab, _>("greet", &"Howdy".to_string())?;
360+
361+
let val = conn.query_row("select * from greet('partner')", [], |row| <(String,)>::try_from(row))?;
362+
assert_eq!(val, ("Howdy partner".to_string(),));
363+
364+
Ok(())
365+
}
366+
290367
#[cfg(feature = "vtab-loadable")]
291368
use duckdb_loadable_macros::duckdb_entrypoint;
292369

0 commit comments

Comments
 (0)