Skip to content

Commit 7222612

Browse files
committed
Add overloads for TVFs, add test
1 parent 5faf28b commit 7222612

File tree

5 files changed

+105
-20
lines changed

5 files changed

+105
-20
lines changed

DuckDB.NET.Data/DuckDBConnection.ScalarFunction.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using DuckDB.NET.Data.Writer;
77
using DuckDB.NET.Native;
88
using System;
9+
using System.Collections.Generic;
910
using System.Diagnostics.CodeAnalysis;
1011
using System.Runtime.CompilerServices;
1112
using System.Runtime.InteropServices;
@@ -16,27 +17,27 @@ partial class DuckDBConnection
1617
{
1718
#if NET8_0_OR_GREATER
1819
[Experimental("DuckDBNET001")]
19-
public void RegisterScalarFunction<TResult>(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, bool isPureFunction = false)
20+
public void RegisterScalarFunction<TResult>(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, bool isPureFunction = false)
2021
{
2122
RegisterScalarMethod(name, action, DuckDBTypeMap.GetLogicalType<TResult>(), varargs: false, !isPureFunction);
2223
}
2324

2425
[Experimental("DuckDBNET001")]
25-
public void RegisterScalarFunction<T, TResult>(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, bool isPureFunction = true, bool @params = false)
26+
public void RegisterScalarFunction<T, TResult>(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, bool isPureFunction = true, bool @params = false)
2627
{
2728
RegisterScalarMethod(name, action, DuckDBTypeMap.GetLogicalType<TResult>(), @params, !isPureFunction, DuckDBTypeMap.GetLogicalType<T>());
2829
}
2930

3031
[Experimental("DuckDBNET001")]
31-
public void RegisterScalarFunction<T1, T2, TResult>(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
32+
public void RegisterScalarFunction<T1, T2, TResult>(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
3233
{
3334
RegisterScalarMethod(name, action, DuckDBTypeMap.GetLogicalType<TResult>(), varargs: false, !isPureFunction,
3435
DuckDBTypeMap.GetLogicalType<T1>(),
3536
DuckDBTypeMap.GetLogicalType<T2>());
3637
}
3738

3839
[Experimental("DuckDBNET001")]
39-
public void RegisterScalarFunction<T1, T2, T3, TResult>(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
40+
public void RegisterScalarFunction<T1, T2, T3, TResult>(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
4041
{
4142
RegisterScalarMethod(name, action, DuckDBTypeMap.GetLogicalType<TResult>(), varargs: false, !isPureFunction,
4243
DuckDBTypeMap.GetLogicalType<T1>(),
@@ -45,7 +46,7 @@ public void RegisterScalarFunction<T1, T2, T3, TResult>(string name, Action<IDuc
4546
}
4647

4748
[Experimental("DuckDBNET001")]
48-
public void RegisterScalarFunction<T1, T2, T3, T4, TResult>(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
49+
public void RegisterScalarFunction<T1, T2, T3, T4, TResult>(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, bool isPureFunction = true)
4950
{
5051
RegisterScalarMethod(name, action, DuckDBTypeMap.GetLogicalType<TResult>(), varargs: false, !isPureFunction,
5152
DuckDBTypeMap.GetLogicalType<T1>(),
@@ -55,11 +56,14 @@ public void RegisterScalarFunction<T1, T2, T3, T4, TResult>(string name, Action<
5556
}
5657

5758
[Experimental("DuckDBNET001")]
58-
private unsafe void RegisterScalarMethod(string name, Action<IDuckDBDataReader[], IDuckDBDataWriter, ulong> action, DuckDBLogicalType returnType,
59+
private unsafe void RegisterScalarMethod(string name, Action<IReadOnlyList<IDuckDBDataReader>, IDuckDBDataWriter, ulong> action, DuckDBLogicalType returnType,
5960
bool varargs, bool @volatile, params DuckDBLogicalType[] parameterTypes)
6061
{
6162
var function = NativeMethods.ScalarFunction.DuckDBCreateScalarFunction();
62-
NativeMethods.ScalarFunction.DuckDBScalarFunctionSetName(function, name.ToUnmanagedString());
63+
using (var handle = name.ToUnmanagedString())
64+
{
65+
NativeMethods.ScalarFunction.DuckDBScalarFunctionSetName(function, handle);
66+
}
6367

6468
if (varargs)
6569
{

DuckDB.NET.Data/DuckDBConnection.TableFunction.cs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,65 @@ partial class DuckDBConnection
2020
{
2121
#if NET8_0_OR_GREATER
2222
[Experimental("DuckDBNET001")]
23-
public unsafe void RegisterTableFunction<T>(string name, Func<IEnumerable<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
23+
public void RegisterTableFunction<T>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
24+
{
25+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T));
26+
}
27+
28+
[Experimental("DuckDBNET001")]
29+
public void RegisterTableFunction<T1, T2>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
30+
{
31+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2));
32+
}
33+
34+
[Experimental("DuckDBNET001")]
35+
public void RegisterTableFunction<T1, T2, T3>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
36+
{
37+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3));
38+
}
39+
40+
[Experimental("DuckDBNET001")]
41+
public void RegisterTableFunction<T1, T2, T3, T4>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
42+
{
43+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4));
44+
}
45+
46+
[Experimental("DuckDBNET001")]
47+
public void RegisterTableFunction<T1, T2, T3, T4, T5>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
48+
{
49+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5));
50+
}
51+
52+
[Experimental("DuckDBNET001")]
53+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
54+
{
55+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6));
56+
}
57+
58+
[Experimental("DuckDBNET001")]
59+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
60+
{
61+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7));
62+
}
63+
64+
[Experimental("DuckDBNET001")]
65+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7, T8>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
66+
{
67+
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8));
68+
}
69+
70+
[Experimental("DuckDBNET001")]
71+
private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback, params Type[] parameterTypes)
2472
{
2573
var function = NativeMethods.TableFunction.DuckDBCreateTableFunction();
26-
NativeMethods.TableFunction.DuckDBTableFunctionSetName(function, name.ToUnmanagedString());
74+
using (var handle = name.ToUnmanagedString())
75+
{
76+
NativeMethods.TableFunction.DuckDBTableFunctionSetName(function, handle);
77+
}
2778

28-
using (var logicalType = DuckDBTypeMap.GetLogicalType<T>())
79+
foreach (var type in parameterTypes)
2980
{
81+
using var logicalType = DuckDBTypeMap.GetLogicalType(type);
3082
NativeMethods.TableFunction.DuckDBTableFunctionAddParameter(function, logicalType);
3183
}
3284

@@ -69,7 +121,7 @@ public static unsafe void Bind(IntPtr info)
69121

70122
foreach (var parameter in parameters)
71123
{
72-
(parameter as IDisposable)?.Dispose();
124+
((DuckDBValue)parameter).Dispose();
73125
}
74126

75127
foreach (var columnInfo in tableFunctionData.Columns)

DuckDB.NET.Data/Internal/TableFunctionInfo.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
namespace DuckDB.NET.Data.Internal;
88

9-
class TableFunctionInfo(Func<IEnumerable<IDuckDBValueReader>, TableFunction> bind, Action<object?, VectorDataWriterBase[], ulong> mapper)
9+
class TableFunctionInfo(Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> bind, Action<object?, VectorDataWriterBase[], ulong> mapper)
1010
{
11-
public Func<IEnumerable<IDuckDBValueReader>, TableFunction> Bind { get; private set; } = bind;
11+
public Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> Bind { get; private set; } = bind;
1212
public Action<object?, VectorDataWriterBase[], ulong> Mapper { get; private set; } = mapper;
1313
}
1414

DuckDB.NET.Test/ScalarFunctionTests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ public void RegisterScalarFunctionWithVarargs()
2222
{
2323
var value = 0L;
2424

25-
if (readers.Length == 0)
25+
if (readers.Count == 0)
2626
{
2727
value = Random.Shared.NextInt64();
2828
}
2929

30-
if (readers.Length == 1)
30+
if (readers.Count == 1)
3131
{
3232
value = Random.Shared.NextInt64(readers[0].GetValue<long>(index));
3333
}
3434

35-
if (readers.Length == 2)
35+
if (readers.Count == 2)
3636
{
3737
value = Random.Shared.NextInt64(readers[0].GetValue<long>(index), readers[1].GetValue<long>(index));
3838
}
@@ -64,7 +64,7 @@ public void RegisterScalarFunctionWithVarargs()
6464
public void RegisterScalarFunctionWithoutParameters()
6565
{
6666
var values = new List<long>();
67-
Connection.RegisterScalarFunction<long>("my_random", (readers, writer, rowCount) =>
67+
Connection.RegisterScalarFunction<long>("my_random", (_, writer, rowCount) =>
6868
{
6969
for (ulong index = 0; index < rowCount; index++)
7070
{

DuckDB.NET.Test/TableFunctionTests.cs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Dapper;
1+
using System;
2+
using Dapper;
23
using DuckDB.NET.Data;
34
using FluentAssertions;
45
using System.Collections.Generic;
@@ -16,7 +17,7 @@ public void RegisterTableFunctionWithOneParameter()
1617
{
1718
Connection.RegisterTableFunction<int>("demo", (parameters) =>
1819
{
19-
var value = parameters.ElementAt(0).GetValue<int>();
20+
var value = parameters[0].GetValue<int>();
2021

2122
return new TableFunction(new List<ColumnInfo>()
2223
{
@@ -38,7 +39,7 @@ public void RegisterTableFunctionWithOneParameterTwoColumns()
3839

3940
Connection.RegisterTableFunction<int>("demo2", (parameters) =>
4041
{
41-
var value = parameters.ElementAt(0).GetValue<int>();
42+
var value = parameters[0].GetValue<int>();
4243

4344
return new TableFunction(new List<ColumnInfo>()
4445
{
@@ -56,4 +57,32 @@ public void RegisterTableFunctionWithOneParameterTwoColumns()
5657
data.Select(tuple => tuple.Item1).Should().BeEquivalentTo(Enumerable.Range(0, count));
5758
data.Select(tuple => tuple.Item2).Should().BeEquivalentTo(Enumerable.Range(0, count).Select(i => $"string{i}"));
5859
}
60+
61+
[Fact]
62+
public void RegisterTableFunctionWithTwoParameterTwoColumns()
63+
{
64+
var count = 50;
65+
66+
Connection.RegisterTableFunction<short, string>("demo3", (parameters) =>
67+
{
68+
var start = parameters[0].GetValue<short>();
69+
var prefix = parameters[1].GetValue<string>();
70+
71+
return new TableFunction(new List<ColumnInfo>()
72+
{
73+
new ColumnInfo("foo", typeof(int)),
74+
new ColumnInfo("bar", typeof(string)),
75+
}, Enumerable.Range(start, count).Select(index => KeyValuePair.Create(index, prefix + index)));
76+
}, (item, writers, rowIndex) =>
77+
{
78+
var pair = (KeyValuePair<int, string>)item;
79+
writers[0].WriteValue(pair.Key, rowIndex);
80+
writers[1].WriteValue(pair.Value, rowIndex);
81+
});
82+
83+
var data = Connection.Query<(int, string)>($"SELECT * FROM demo3(30::SmallInt, 'DuckDB');").ToList();
84+
85+
data.Select(tuple => tuple.Item1).Should().BeEquivalentTo(Enumerable.Range(30, count));
86+
data.Select(tuple => tuple.Item2).Should().BeEquivalentTo(Enumerable.Range(30, count).Select(i => $"DuckDB{i}"));
87+
}
5988
}

0 commit comments

Comments
 (0)