Skip to content

Commit 0f7fb41

Browse files
authored
Add test that returns nb::ndarray<nb::memview> (#1181)
1 parent 4c356d2 commit 0f7fb41

File tree

3 files changed

+79
-19
lines changed

3 files changed

+79
-19
lines changed

tests/test_ndarray.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ NB_MODULE(test_ndarray_ext, m) {
240240
});
241241

242242
m.def("destruct_count", []() { return destruct_count; });
243-
m.def("return_dlpack", []() {
243+
244+
m.def("return_no_framework", []() {
244245
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
245246
size_t shape[2] = { 2, 4 };
246247

@@ -299,16 +300,28 @@ NB_MODULE(test_ndarray_ext, m) {
299300
deleter);
300301
});
301302

303+
m.def("ret_memview", []() {
304+
double *d = new double[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
305+
size_t shape[2] = { 2, 4 };
306+
307+
nb::capsule deleter(d, [](void *data) noexcept {
308+
destruct_count++;
309+
delete[] (double *) data;
310+
});
311+
312+
return nb::ndarray<nb::memview, double, nb::shape<2, 4>>(d, 2, shape,
313+
deleter);
314+
});
315+
302316
m.def("ret_array_scalar", []() {
303-
float* f = new float[1] { 1 };
304-
size_t shape[1] = {};
317+
float* f = new float{ 1.0f };
305318

306319
nb::capsule deleter(f, [](void* data) noexcept {
307320
destruct_count++;
308-
delete[] (float *) data;
321+
delete (float *) data;
309322
});
310323

311-
return nb::ndarray<nb::numpy, float>(f, 0, shape, deleter);
324+
return nb::ndarray<nb::numpy, float>(f, 0, nullptr, deleter);
312325
});
313326

314327
m.def("noop_3d_c_contig",

tests/test_ndarray.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import warnings
44
import importlib
5-
from common import collect
5+
from common import collect, skip_on_pypy
66

77
try:
88
import numpy as np
@@ -209,50 +209,77 @@ def test11_implicit_conversion_pytorch():
209209
t.noimplicit(torch.zeros(2, 2, 10, dtype=torch.float32)[:, :, 4])
210210

211211

212-
def test14_destroy_capsule():
212+
@needs_numpy
213+
def test12_process_image():
214+
x = np.arange(120, dtype=np.ubyte).reshape(8, 5, 3)
215+
t.process(x)
216+
assert np.all(x == np.arange(0, 240, 2, dtype=np.ubyte).reshape(8, 5, 3))
217+
218+
219+
def test13_destroy_capsule():
213220
collect()
214221
dc = t.destruct_count()
215-
a = t.return_dlpack()
216-
assert dc == t.destruct_count()
217-
del a
222+
capsule = t.return_no_framework()
223+
assert 'dltensor' in repr(capsule)
224+
assert 'versioned' not in repr(capsule)
225+
assert t.destruct_count() == dc
226+
del capsule
218227
collect()
219228
assert t.destruct_count() - dc == 1
220229

221230

222231
@needs_numpy
223-
def test15_consume_numpy():
232+
def test14_consume_numpy():
224233
collect()
225234
class wrapper:
226235
def __init__(self, value):
227236
self.value = value
228237
def __dlpack__(self):
229238
return self.value
230239
dc = t.destruct_count()
231-
a = t.return_dlpack()
240+
capsule = t.return_no_framework()
232241
if hasattr(np, '_from_dlpack'):
233-
x = np._from_dlpack(wrapper(a))
242+
x = np._from_dlpack(wrapper(capsule))
234243
elif hasattr(np, 'from_dlpack'):
235-
x = np.from_dlpack(wrapper(a))
244+
x = np.from_dlpack(wrapper(capsule))
236245
else:
237246
pytest.skip('your version of numpy is too old')
238247

239-
del a
248+
del capsule
240249
collect()
241250
assert x.shape == (2, 4)
242251
assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
243-
assert dc == t.destruct_count()
252+
assert t.destruct_count() == dc
244253
del x
245254
collect()
246255
assert t.destruct_count() - dc == 1
247256

248257

249258
@needs_numpy
250-
def test16_passthrough():
259+
def test15_passthrough_numpy():
251260
a = t.ret_numpy()
252261
b = t.passthrough(a)
253262
assert a is b
254263

255-
a = np.array([1,2,3])
264+
a = np.array([1, 2, 3])
265+
b = t.passthrough(a)
266+
assert a is b
267+
268+
a = None
269+
with pytest.raises(TypeError) as excinfo:
270+
b = t.passthrough(a)
271+
assert 'incompatible function arguments' in str(excinfo.value)
272+
b = t.passthrough_arg_none(a)
273+
assert a is b
274+
275+
276+
@needs_torch
277+
def test16_passthrough_torch():
278+
a = t.ret_pytorch()
279+
b = t.passthrough(a)
280+
assert a is b
281+
282+
a = torch.tensor([1, 2, 3])
256283
b = t.passthrough(a)
257284
assert a is b
258285

@@ -292,6 +319,22 @@ def test18_return_pytorch():
292319
assert t.destruct_count() - dc == 1
293320

294321

322+
@skip_on_pypy
323+
def test19_return_memview():
324+
collect()
325+
dc = t.destruct_count()
326+
x = t.ret_memview()
327+
assert isinstance(x, memoryview)
328+
assert x.itemsize == 8
329+
assert x.ndim == 2
330+
assert x.shape == (2, 4)
331+
assert x.strides == (32, 8) # in bytes
332+
assert x.tolist() == [[1, 2, 3, 4], [5, 6, 7, 8]]
333+
del x
334+
collect()
335+
assert t.destruct_count() - dc == 1
336+
337+
295338
@needs_numpy
296339
def test21_return_array_scalar():
297340
collect()
@@ -510,6 +553,7 @@ def test33_force_contig_numpy():
510553
assert b is not a
511554
assert np.all(b == a)
512555

556+
513557
@needs_torch
514558
@pytest.mark.filterwarnings
515559
def test34_force_contig_pytorch():
@@ -567,6 +611,7 @@ def test36_half():
567611
assert x.shape == (2, 4)
568612
assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
569613

614+
570615
@needs_numpy
571616
def test37_cast():
572617
a = t.cast(False)

tests/test_ndarray_ext.pyi.ref

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def process(arg: Annotated[NDArray[numpy.uint8], dict(shape=(None, None, 3), ord
9797

9898
def destruct_count() -> int: ...
9999

100-
def return_dlpack() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
100+
def return_no_framework() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
101101

102102
def passthrough(arg: NDArray, /) -> NDArray: ...
103103

@@ -115,6 +115,8 @@ def ret_numpy_const() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), wr
115115

116116
def ret_pytorch() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
117117

118+
def ret_memview() -> memoryview[dtype=float64, shape=(2, 4)]: ...
119+
118120
def ret_array_scalar() -> NDArray[numpy.float32]: ...
119121

120122
def noop_3d_c_contig(arg: Annotated[NDArray[numpy.float32], dict(shape=(None, None, None), order='C')], /) -> None: ...

0 commit comments

Comments
 (0)