Skip to content

Commit 3bc637a

Browse files
committed
Use parallel terminology for registration helpers
1 parent 38cc0ec commit 3bc637a

File tree

5 files changed

+56
-51
lines changed

5 files changed

+56
-51
lines changed

crates/duckdb/src/r2d2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ impl DuckdbConnectionManager {
9292
#[cfg(feature = "vscalar")]
9393
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> Result<()>
9494
where
95-
S::State: Debug + Default,
95+
S::ExtraInfo: Debug + Default,
9696
{
9797
let conn = self.connection.lock().unwrap();
9898
conn.register_scalar_function::<S>(name)

crates/duckdb/src/vscalar/arrow.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ impl ArrowFunctionSignature {
7373

7474
/// A trait for scalar functions that accept and return arrow types that can be registered with DuckDB
7575
pub trait VArrowScalar: Sized {
76-
/// State set at registration time. Persists for the lifetime of the catalog entry.
76+
/// Extra info set at registration time. Persists for the lifetime of the catalog entry.
7777
/// Shared across worker threads and invocations — must not be modified during execution.
7878
/// Must be `'static` as it is stored in DuckDB and may outlive the current stack frame.
79-
type State: Default + Sized + Send + Sync + 'static;
79+
type ExtraInfo: Default + Sized + Send + Sync + 'static;
8080

8181
/// The actual function that is called by DuckDB
82-
fn invoke(state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>>;
82+
fn invoke(extra_info: &Self::ExtraInfo, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>>;
8383

8484
/// The possible signatures of the scalar function. These will result in DuckDB scalar function overloads.
8585
/// The invoke method should be able to handle all of these signatures.
@@ -90,14 +90,14 @@ impl<T> VScalar for T
9090
where
9191
T: VArrowScalar,
9292
{
93-
type State = T::State;
93+
type ExtraInfo = T::ExtraInfo;
9494

9595
unsafe fn invoke(
96-
state: &Self::State,
96+
extra_info: &Self::ExtraInfo,
9797
input: &mut DataChunkHandle,
9898
out: &mut dyn WritableVector,
9999
) -> Result<(), Box<dyn std::error::Error>> {
100-
let array = T::invoke(state, data_chunk_to_arrow(input)?)?;
100+
let array = T::invoke(extra_info, data_chunk_to_arrow(input)?)?;
101101
write_arrow_array_to_vector(&array, out)
102102
}
103103

@@ -129,9 +129,9 @@ mod test {
129129
struct HelloScalarArrow {}
130130

131131
impl VArrowScalar for HelloScalarArrow {
132-
type State = ();
132+
type ExtraInfo = ();
133133

134-
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
134+
fn invoke(_: &Self::ExtraInfo, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
135135
let name = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
136136
let result = name.iter().map(|v| format!("Hello {}", v.unwrap())).collect::<Vec<_>>();
137137
Ok(Arc::new(StringArray::from(result)))
@@ -164,9 +164,9 @@ mod test {
164164
struct ArrowMultiplyScalar {}
165165

166166
impl VArrowScalar for ArrowMultiplyScalar {
167-
type State = MockState;
167+
type ExtraInfo = MockState;
168168

169-
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
169+
fn invoke(_: &Self::ExtraInfo, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
170170
let a = input
171171
.column(0)
172172
.as_any()
@@ -199,10 +199,13 @@ mod test {
199199
struct ArrowOverloaded {}
200200

201201
impl VArrowScalar for ArrowOverloaded {
202-
type State = MockState;
202+
type ExtraInfo = MockState;
203203

204-
fn invoke(state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
205-
assert_eq!("some meta", state.info);
204+
fn invoke(
205+
extra_info: &Self::ExtraInfo,
206+
input: RecordBatch,
207+
) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
208+
assert_eq!("some meta", extra_info.info);
206209

207210
let a = input.column(0);
208211
let b = input.column(1);
@@ -336,9 +339,9 @@ mod test {
336339
struct SplitFunction {}
337340

338341
impl VArrowScalar for SplitFunction {
339-
type State = ();
342+
type ExtraInfo = ();
340343

341-
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
344+
fn invoke(_: &Self::ExtraInfo, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
342345
let strings = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
343346

344347
let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(

crates/duckdb/src/vscalar/mod.rs

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
2323

2424
/// Duckdb scalar function trait
2525
pub trait VScalar: Sized {
26-
/// State set at registration time. Persists for the lifetime of the catalog entry.
26+
/// Extra info set at registration time. Persists for the lifetime of the catalog entry.
2727
/// Shared across worker threads and invocations — must not be modified during execution.
2828
/// Must be `'static` as it is stored in DuckDB and may outlive the current stack frame.
29-
type State: Sized + Send + Sync + 'static;
29+
type ExtraInfo: Sized + Send + Sync + 'static;
3030
/// The actual function
3131
///
3232
/// # Safety
@@ -36,7 +36,7 @@ pub trait VScalar: Sized {
3636
/// - Dereferences multiple raw pointers (`func`).
3737
///
3838
unsafe fn invoke(
39-
state: &Self::State,
39+
extra_info: &Self::ExtraInfo,
4040
input: &mut DataChunkHandle,
4141
output: &mut dyn WritableVector,
4242
) -> Result<(), Box<dyn std::error::Error>>;
@@ -110,7 +110,7 @@ impl From<duckdb_function_info> for ScalarFunctionInfo {
110110
}
111111

112112
impl ScalarFunctionInfo {
113-
pub unsafe fn get_scalar_extra_info<T>(&self) -> &T {
113+
pub unsafe fn get_extra_info<T>(&self) -> &T {
114114
&*(duckdb_scalar_function_get_extra_info(self.0).cast())
115115
}
116116

@@ -126,44 +126,48 @@ where
126126
{
127127
let info = ScalarFunctionInfo::from(info);
128128
let mut input = DataChunkHandle::new_unowned(input);
129-
let result = T::invoke(info.get_scalar_extra_info(), &mut input, &mut output);
129+
let result = T::invoke(info.get_extra_info(), &mut input, &mut output);
130130
if let Err(e) = result {
131131
info.set_error(&e.to_string());
132132
}
133133
}
134134

135135
impl Connection {
136-
/// Register the given ScalarFunction with default state.
136+
/// Register the given ScalarFunction with default extra info.
137137
#[inline]
138138
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()>
139139
where
140-
S::State: Default,
140+
S::ExtraInfo: Default,
141141
{
142142
let set = ScalarFunctionSet::new(name);
143143
for signature in S::signatures() {
144144
let scalar_function = ScalarFunction::new(name)?;
145145
signature.register_with_scalar(&scalar_function);
146146
scalar_function.set_function(Some(scalar_func::<S>));
147-
scalar_function.set_extra_info(S::State::default());
147+
scalar_function.set_extra_info(S::ExtraInfo::default());
148148
set.add_function(scalar_function)?;
149149
}
150150
self.db.borrow_mut().register_scalar_function_set(set)
151151
}
152152

153-
/// Register the given ScalarFunction with custom state.
153+
/// Register the given ScalarFunction with custom extra info.
154154
///
155-
/// The state is cloned once per function signature (overload) and stored in DuckDB's catalog.
155+
/// The extra info is cloned once per function signature (overload) and stored in DuckDB's catalog.
156156
#[inline]
157-
pub fn register_scalar_function_with_state<S: VScalar>(&self, name: &str, state: &S::State) -> crate::Result<()>
157+
pub fn register_scalar_function_with_extra_info<S: VScalar>(
158+
&self,
159+
name: &str,
160+
extra_info: &S::ExtraInfo,
161+
) -> crate::Result<()>
158162
where
159-
S::State: Clone,
163+
S::ExtraInfo: Clone,
160164
{
161165
let set = ScalarFunctionSet::new(name);
162166
for signature in S::signatures() {
163167
let scalar_function = ScalarFunction::new(name)?;
164168
signature.register_with_scalar(&scalar_function);
165169
scalar_function.set_function(Some(scalar_func::<S>));
166-
scalar_function.set_extra_info(state.clone());
170+
scalar_function.set_extra_info(extra_info.clone());
167171
set.add_function(scalar_function)?;
168172
}
169173
self.db.borrow_mut().register_scalar_function_set(set)
@@ -196,10 +200,10 @@ mod test {
196200
struct ErrorScalar {}
197201

198202
impl VScalar for ErrorScalar {
199-
type State = ();
203+
type ExtraInfo = ();
200204

201205
unsafe fn invoke(
202-
_: &Self::State,
206+
_: &Self::ExtraInfo,
203207
input: &mut DataChunkHandle,
204208
_: &mut dyn WritableVector,
205209
) -> Result<(), Box<dyn std::error::Error>> {
@@ -234,10 +238,10 @@ mod test {
234238
struct EchoScalar {}
235239

236240
impl VScalar for EchoScalar {
237-
type State = TestState;
241+
type ExtraInfo = TestState;
238242

239243
unsafe fn invoke(
240-
state: &Self::State,
244+
extra_info: &Self::ExtraInfo,
241245
input: &mut DataChunkHandle,
242246
output: &mut dyn WritableVector,
243247
) -> Result<(), Box<dyn std::error::Error>> {
@@ -250,7 +254,7 @@ mod test {
250254
let output = output.flat_vector();
251255

252256
for s in strings {
253-
let res = format!("{}: {}", state.prefix, s.repeat(state.multiplier));
257+
let res = format!("{}: {}", extra_info.prefix, s.repeat(extra_info.multiplier));
254258
output.insert(0, res.as_str());
255259
}
256260
Ok(())
@@ -267,10 +271,10 @@ mod test {
267271
struct Repeat {}
268272

269273
impl VScalar for Repeat {
270-
type State = ();
274+
type ExtraInfo = ();
271275

272276
unsafe fn invoke(
273-
_: &Self::State,
277+
_: &Self::ExtraInfo,
274278
input: &mut DataChunkHandle,
275279
output: &mut dyn WritableVector,
276280
) -> Result<(), Box<dyn std::error::Error>> {
@@ -317,9 +321,9 @@ mod test {
317321
}
318322
}
319323

320-
// Test with custom state
324+
// Test with custom extra info
321325
{
322-
conn.register_scalar_function_with_state::<EchoScalar>(
326+
conn.register_scalar_function_with_extra_info::<EchoScalar>(
323327
"echo2",
324328
&TestState {
325329
multiplier: 5,

crates/duckdb/src/vtab/function.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ impl BindInfo {
111111
pub fn set_cardinality(&self, cardinality: idx_t, is_exact: bool) {
112112
unsafe { duckdb_bind_set_cardinality(self.ptr, cardinality, is_exact) }
113113
}
114-
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
114+
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`].
115115
pub fn get_extra_info<T>(&self) -> *const T {
116116
unsafe { duckdb_bind_get_extra_info(self.ptr).cast() }
117117
}
@@ -163,7 +163,7 @@ impl InitInfo {
163163
indices
164164
}
165165

166-
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
166+
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`].
167167
pub fn get_extra_info<T>(&self) -> *const T {
168168
unsafe { duckdb_init_get_extra_info(self.0).cast() }
169169
}
@@ -306,35 +306,33 @@ impl TableFunction {
306306
self
307307
}
308308

309-
/// Assigns extra information to the table function that can be fetched during binding, etc.
309+
/// Assigns extra information to the table function using raw pointers.
310310
///
311-
/// For most use cases, prefer [`with_extra_info`](Self::with_extra_info) which handles memory management automatically.
311+
/// For most use cases, prefer [`set_extra_info`](Self::set_extra_info) which handles memory management automatically.
312312
///
313313
/// # Arguments
314-
/// * `extra_info`: The extra information
315-
/// * `destroy`: The callback that will be called to destroy the bind data (if any)
314+
/// * `extra_info`: The extra information as a raw pointer
315+
/// * `destroy`: The callback that will be called to destroy the data (if any)
316316
///
317317
/// # Safety
318318
/// The caller must ensure that `extra_info` is a valid pointer and that `destroy`
319319
/// properly cleans up the data when called.
320-
pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) {
320+
pub unsafe fn set_extra_info_raw(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) {
321321
duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy);
322322
}
323323

324324
/// Assigns extra information to the table function that can be fetched during binding, init, and execution.
325325
///
326-
/// This is a safe wrapper around [`set_extra_info`](Self::set_extra_info) that handles memory management automatically.
327-
///
328326
/// # Arguments
329327
/// * `info`: The extra information to store
330-
pub fn with_extra_info<T>(&self, info: T) -> &Self
328+
pub fn set_extra_info<T>(&self, info: T) -> &Self
331329
where
332330
T: Send + Sync + 'static,
333331
{
334332
unsafe {
335333
let boxed = Box::new(info);
336334
let ptr = Box::into_raw(boxed) as *mut c_void;
337-
self.set_extra_info(ptr, Some(drop_boxed::<T>));
335+
self.set_extra_info_raw(ptr, Some(drop_boxed::<T>));
338336
}
339337
self
340338
}
@@ -404,7 +402,7 @@ impl<V: VTab> TableFunctionInfo<V> {
404402
}
405403
}
406404

407-
/// Retrieves the extra info of the function as set in [`TableFunction::with_extra_info`].
405+
/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`].
408406
pub fn get_extra_info<T>(&self) -> *mut T {
409407
unsafe { duckdb_function_get_extra_info(self.ptr).cast() }
410408
}

crates/duckdb/src/vtab/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ impl Connection {
169169
.set_bind(Some(bind::<T>))
170170
.set_init(Some(init::<T>))
171171
.set_function(Some(func::<T>))
172-
.with_extra_info(extra_info.clone());
172+
.set_extra_info(extra_info.clone());
173173
for ty in T::parameters().unwrap_or_default() {
174174
table_function.add_parameter(&ty);
175175
}

0 commit comments

Comments
 (0)