Skip to content

Commit 2576d4d

Browse files
committed
Add TVF tests
1 parent f98ae60 commit 2576d4d

File tree

6 files changed

+154
-20
lines changed

6 files changed

+154
-20
lines changed

DuckDB.NET.Bindings/Bindings.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,8 @@ Updated to DuckDB v1.1.2
4747
<ItemGroup>
4848
<PackageReference Include="GitVersion.MsBuild" Version="5.11.1" PrivateAssets="all" />
4949
</ItemGroup>
50+
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
51+
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="6.0.0" />
52+
</ItemGroup>
5053
<!-- End native lib section -->
5154
</Project>

DuckDB.NET.Bindings/DuckDBWrapperObjects.cs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Microsoft.Win32.SafeHandles;
22
using System;
3+
using System.Numerics;
4+
using System.Runtime.CompilerServices;
35

46
namespace DuckDB.NET.Native;
57

@@ -105,14 +107,49 @@ protected override bool ReleaseHandle()
105107
}
106108

107109
internal void SetChildValues(DuckDBValue[] values)
108-
{
110+
{
109111
childValues = values;
110112
}
111113

112114
public T GetValue<T>()
113115
{
114-
var value = NativeMethods.Value.DuckDBGetInt32(this);
115-
return (T)(object)value;
116-
//return Unsafe.As<int, T>(ref value);
116+
var type = typeof(T);
117+
var logicalType = NativeMethods.Value.DuckDBGetValueType(this);
118+
var duckDBType = NativeMethods.LogicalType.DuckDBGetTypeId(logicalType);
119+
120+
return duckDBType switch
121+
{
122+
DuckDBType.Boolean => ReadValue<bool>(NativeMethods.Value.DuckDBGetBool(this)),
123+
124+
DuckDBType.TinyInt => ReadValue<sbyte>(NativeMethods.Value.DuckDBGetInt8(this)),
125+
DuckDBType.SmallInt => ReadValue<short>(NativeMethods.Value.DuckDBGetInt16(this)),
126+
DuckDBType.Integer => ReadValue<int>(NativeMethods.Value.DuckDBGetInt32(this)),
127+
DuckDBType.BigInt => ReadValue<long>(NativeMethods.Value.DuckDBGetInt64(this)),
128+
129+
DuckDBType.UnsignedTinyInt => ReadValue<byte>(NativeMethods.Value.DuckDBGetUInt8(this)),
130+
DuckDBType.UnsignedSmallInt => ReadValue<ushort>(NativeMethods.Value.DuckDBGetUInt16(this)),
131+
DuckDBType.UnsignedInteger => ReadValue<uint>(NativeMethods.Value.DuckDBGetUInt32(this)),
132+
DuckDBType.UnsignedBigInt => ReadValue<ulong>(NativeMethods.Value.DuckDBGetUInt64(this)),
133+
134+
DuckDBType.Float => ReadValue<float>(NativeMethods.Value.DuckDBGetFloat(this)),
135+
DuckDBType.Double => ReadValue<double>(NativeMethods.Value.DuckDBGetDouble(this)),
136+
137+
//DuckDBType.Timestamp => ReadValue<T>(),
138+
//DuckDBType.Interval => ReadValue<T>(),
139+
//DuckDBType.Date => ReadValue<T>(),
140+
//DuckDBType.Time => ReadValue<T>(),
141+
//DuckDBType.TimeTz => ReadValue<T>(),
142+
//DuckDBType.HugeInt => ReadValue<DuckDBHugeInt>(NativeMethods.Value.DuckDBGetHugeInt(this)),
143+
//DuckDBType.UnsignedHugeInt => ReadValue<T>(),
144+
DuckDBType.Varchar => ReadValue<string>(NativeMethods.Value.DuckDBGetVarchar(this)),
145+
//DuckDBType.Decimal => ReadValue<T>(),
146+
//DuckDBType.Uuid => expr,
147+
_ => throw new NotImplementedException($"Cannot read value of type {type.FullName}")
148+
};
149+
150+
T ReadValue<TSource>(TSource value)
151+
{
152+
return Unsafe.As<TSource, T>(ref value);
153+
}
117154
}
118-
}
155+
}

DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,50 @@ public static class Value
7171
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_create_blob")]
7272
public static extern DuckDBValue DuckDBCreateBlob([In] byte[] value, long length);
7373

74+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_bool")]
75+
public static extern bool DuckDBGetBool(DuckDBValue value);
76+
77+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_int8")]
78+
public static extern sbyte DuckDBGetInt8(DuckDBValue value);
79+
80+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_uint8")]
81+
public static extern byte DuckDBGetUInt8(DuckDBValue value);
82+
83+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_int16")]
84+
public static extern short DuckDBGetInt16(DuckDBValue value);
85+
86+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_uint16")]
87+
public static extern ushort DuckDBGetUInt16(DuckDBValue value);
88+
7489
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_int32")]
7590
public static extern int DuckDBGetInt32(DuckDBValue value);
91+
92+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_uint32")]
93+
public static extern uint DuckDBGetUInt32(DuckDBValue value);
94+
95+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_int64")]
96+
public static extern long DuckDBGetInt64(DuckDBValue value);
97+
98+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_uint64")]
99+
public static extern ulong DuckDBGetUInt64(DuckDBValue value);
100+
101+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_hugeint")]
102+
public static extern DuckDBHugeInt DuckDBGetHugeInt(DuckDBValue value);
103+
104+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_uhugeint")]
105+
public static extern DuckDBUHugeInt DuckDBGetUHugeInt(DuckDBValue value);
106+
107+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_float")]
108+
public static extern float DuckDBGetFloat(DuckDBValue value);
109+
110+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_double")]
111+
public static extern double DuckDBGetDouble(DuckDBValue value);
112+
113+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_value_type")]
114+
public static extern unsafe DuckDBLogicalType DuckDBGetValueType(DuckDBValue value);
115+
116+
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_varchar")]
117+
public static extern string DuckDBGetVarchar(DuckDBValue value);
76118

77119
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_create_list_value")]
78120
public static extern DuckDBValue DuckDBCreateListValue(DuckDBLogicalType logicalType, IntPtr[] values, long count);

DuckDB.NET.Data/DuckDBConnection.TableFunction.cs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
11
using DuckDB.NET.Data.Extensions;
22
using DuckDB.NET.Data.Internal;
3+
using DuckDB.NET.Data.Internal.Writer;
4+
using DuckDB.NET.Data.Writer;
35
using DuckDB.NET.Native;
46
using System;
57
using System.Collections;
68
using System.Collections.Generic;
79
using System.Diagnostics.CodeAnalysis;
8-
using System.Linq;
910
using System.Runtime.CompilerServices;
1011
using System.Runtime.InteropServices;
11-
using DuckDB.NET.Data.Internal.Writer;
12-
using DuckDB.NET.Data.Writer;
1312

1413
namespace DuckDB.NET.Data;
1514

16-
public record ColumnInfo(string Name, Type Type)
17-
{
18-
}
15+
public record ColumnInfo(string Name, Type Type);
1916

20-
public record TableFunction(IReadOnlyList<ColumnInfo> Columns, IEnumerable Data)
21-
{
22-
}
17+
public record TableFunction(IReadOnlyList<ColumnInfo> Columns, IEnumerable Data);
2318

2419
partial class DuckDBConnection
2520
{
2621
#if NET8_0_OR_GREATER
2722
[Experimental("DuckDBNET001")]
28-
public unsafe void RegisterTableFunction<T>(string name, Func<IEnumerable<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[]> mapperCallback)
23+
public unsafe void RegisterTableFunction<T>(string name, Func<IEnumerable<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
2924
{
3025
var function = NativeMethods.TableFunction.DuckDBCreateTableFunction();
3126
NativeMethods.TableFunction.DuckDBTableFunctionSetName(function, name.ToUnmanagedString());
@@ -89,10 +84,10 @@ public static unsafe void Bind(IntPtr info)
8984
}
9085

9186
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
92-
public static unsafe void Init(IntPtr info) { }
87+
public static void Init(IntPtr info) { }
9388

9489
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
95-
public static unsafe void TableFunction(IntPtr info, IntPtr chunk)
90+
public static void TableFunction(IntPtr info, IntPtr chunk)
9691
{
9792
var bindData = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetBindData(info));
9893
var extraInfo = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetExtraInfo(info));
@@ -125,7 +120,7 @@ public static unsafe void TableFunction(IntPtr info, IntPtr chunk)
125120
{
126121
if (tableFunctionBindData.DataEnumerator.MoveNext())
127122
{
128-
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers);
123+
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers, size);
129124
}
130125
else
131126
{

DuckDB.NET.Data/Internal/TableFunctionInfo.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
namespace DuckDB.NET.Data.Internal;
88

9-
class TableFunctionInfo(Func<IEnumerable<IDuckDBValueReader>, TableFunction> bind, Action<object?, VectorDataWriterBase[]> mapper)
9+
class TableFunctionInfo(Func<IEnumerable<IDuckDBValueReader>, TableFunction> bind, Action<object?, VectorDataWriterBase[], ulong> mapper)
1010
{
1111
public Func<IEnumerable<IDuckDBValueReader>, TableFunction> Bind { get; private set; } = bind;
12-
public Action<object?, VectorDataWriterBase[]> Mapper { get; private set; } = mapper;
12+
public Action<object?, VectorDataWriterBase[], ulong> Mapper { get; private set; } = mapper;
1313
}
1414

1515
class TableFunctionBindData(IReadOnlyList<ColumnInfo> columns, IEnumerator dataEnumerator) : IDisposable
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using System.Collections.Generic;
2+
using System.Diagnostics.CodeAnalysis;
3+
using System.Linq;
4+
using Dapper;
5+
using DuckDB.NET.Data;
6+
using FluentAssertions;
7+
using Xunit;
8+
9+
namespace DuckDB.NET.Test;
10+
11+
[Experimental("DuckDBNET001")]
12+
public class TableFunctionTests(DuckDBDatabaseFixture db) : DuckDBTestBase(db)
13+
{
14+
[Fact]
15+
public void RegisterTableFunctionWithOneParameter()
16+
{
17+
Connection.RegisterTableFunction<int>("demo", (parameters) =>
18+
{
19+
var value = parameters.ElementAt(0).GetValue<int>();
20+
21+
return new TableFunction(new List<ColumnInfo>()
22+
{
23+
new ColumnInfo("foo", typeof(int)),
24+
}, Enumerable.Range(0, value));
25+
}, (item, writers, rowIndex) =>
26+
{
27+
writers[0].WriteValue((int)item, rowIndex);
28+
});
29+
30+
var data = Connection.Query<int>("SELECT * FROM demo(30);");
31+
data.Should().BeEquivalentTo(Enumerable.Range(0, 30));
32+
}
33+
34+
[Fact]
35+
public void RegisterTableFunctionWithOneParameterTwoColumns()
36+
{
37+
Connection.RegisterTableFunction<int>("demo", (parameters) =>
38+
{
39+
var value = parameters.ElementAt(0).GetValue<int>();
40+
41+
return new TableFunction(new List<ColumnInfo>()
42+
{
43+
new ColumnInfo("foo", typeof(int)),
44+
new ColumnInfo("bar", typeof(string)),
45+
}, Enumerable.Range(0, value));
46+
}, (item, writers, rowIndex) =>
47+
{
48+
writers[0].WriteValue((int)item, rowIndex);
49+
writers[1].WriteValue($"string{item}", rowIndex);
50+
});
51+
52+
var data = Connection.Query<(int, string)>("SELECT * FROM demo(3000);").ToList();
53+
54+
data.Select(tuple => tuple.Item1).Should().BeEquivalentTo(Enumerable.Range(0, 3000));
55+
data.Select(tuple => tuple.Item2).Should().BeEquivalentTo(Enumerable.Range(0, 3000).Select(i => $"string{i}"));
56+
}
57+
}

0 commit comments

Comments
 (0)