|
2 | 2 | import pytest |
3 | 3 | import warnings |
4 | 4 | import importlib |
5 | | -from common import collect |
| 5 | +from common import collect, skip_on_pypy |
6 | 6 |
|
7 | 7 | try: |
8 | 8 | import numpy as np |
@@ -209,50 +209,77 @@ def test11_implicit_conversion_pytorch(): |
209 | 209 | t.noimplicit(torch.zeros(2, 2, 10, dtype=torch.float32)[:, :, 4]) |
210 | 210 |
|
211 | 211 |
|
212 | | -def test14_destroy_capsule(): |
| 212 | +@needs_numpy |
| 213 | +def test12_process_image(): |
| 214 | + x = np.arange(120, dtype=np.ubyte).reshape(8, 5, 3) |
| 215 | + t.process(x) |
| 216 | + assert np.all(x == np.arange(0, 240, 2, dtype=np.ubyte).reshape(8, 5, 3)) |
| 217 | + |
| 218 | + |
| 219 | +def test13_destroy_capsule(): |
213 | 220 | collect() |
214 | 221 | dc = t.destruct_count() |
215 | | - a = t.return_dlpack() |
216 | | - assert dc == t.destruct_count() |
217 | | - del a |
| 222 | + capsule = t.return_no_framework() |
| 223 | + assert 'dltensor' in repr(capsule) |
| 224 | + assert 'versioned' not in repr(capsule) |
| 225 | + assert t.destruct_count() == dc |
| 226 | + del capsule |
218 | 227 | collect() |
219 | 228 | assert t.destruct_count() - dc == 1 |
220 | 229 |
|
221 | 230 |
|
222 | 231 | @needs_numpy |
223 | | -def test15_consume_numpy(): |
| 232 | +def test14_consume_numpy(): |
224 | 233 | collect() |
225 | 234 | class wrapper: |
226 | 235 | def __init__(self, value): |
227 | 236 | self.value = value |
228 | 237 | def __dlpack__(self): |
229 | 238 | return self.value |
230 | 239 | dc = t.destruct_count() |
231 | | - a = t.return_dlpack() |
| 240 | + capsule = t.return_no_framework() |
232 | 241 | if hasattr(np, '_from_dlpack'): |
233 | | - x = np._from_dlpack(wrapper(a)) |
| 242 | + x = np._from_dlpack(wrapper(capsule)) |
234 | 243 | elif hasattr(np, 'from_dlpack'): |
235 | | - x = np.from_dlpack(wrapper(a)) |
| 244 | + x = np.from_dlpack(wrapper(capsule)) |
236 | 245 | else: |
237 | 246 | pytest.skip('your version of numpy is too old') |
238 | 247 |
|
239 | | - del a |
| 248 | + del capsule |
240 | 249 | collect() |
241 | 250 | assert x.shape == (2, 4) |
242 | 251 | assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) |
243 | | - assert dc == t.destruct_count() |
| 252 | + assert t.destruct_count() == dc |
244 | 253 | del x |
245 | 254 | collect() |
246 | 255 | assert t.destruct_count() - dc == 1 |
247 | 256 |
|
248 | 257 |
|
249 | 258 | @needs_numpy |
250 | | -def test16_passthrough(): |
| 259 | +def test15_passthrough_numpy(): |
251 | 260 | a = t.ret_numpy() |
252 | 261 | b = t.passthrough(a) |
253 | 262 | assert a is b |
254 | 263 |
|
255 | | - a = np.array([1,2,3]) |
| 264 | + a = np.array([1, 2, 3]) |
| 265 | + b = t.passthrough(a) |
| 266 | + assert a is b |
| 267 | + |
| 268 | + a = None |
| 269 | + with pytest.raises(TypeError) as excinfo: |
| 270 | + b = t.passthrough(a) |
| 271 | + assert 'incompatible function arguments' in str(excinfo.value) |
| 272 | + b = t.passthrough_arg_none(a) |
| 273 | + assert a is b |
| 274 | + |
| 275 | + |
| 276 | +@needs_torch |
| 277 | +def test16_passthrough_torch(): |
| 278 | + a = t.ret_pytorch() |
| 279 | + b = t.passthrough(a) |
| 280 | + assert a is b |
| 281 | + |
| 282 | + a = torch.tensor([1, 2, 3]) |
256 | 283 | b = t.passthrough(a) |
257 | 284 | assert a is b |
258 | 285 |
|
@@ -292,6 +319,22 @@ def test18_return_pytorch(): |
292 | 319 | assert t.destruct_count() - dc == 1 |
293 | 320 |
|
294 | 321 |
|
| 322 | +@skip_on_pypy |
| 323 | +def test19_return_memview(): |
| 324 | + collect() |
| 325 | + dc = t.destruct_count() |
| 326 | + x = t.ret_memview() |
| 327 | + assert isinstance(x, memoryview) |
| 328 | + assert x.itemsize == 8 |
| 329 | + assert x.ndim == 2 |
| 330 | + assert x.shape == (2, 4) |
| 331 | + assert x.strides == (32, 8) # in bytes |
| 332 | + assert x.tolist() == [[1, 2, 3, 4], [5, 6, 7, 8]] |
| 333 | + del x |
| 334 | + collect() |
| 335 | + assert t.destruct_count() - dc == 1 |
| 336 | + |
| 337 | + |
295 | 338 | @needs_numpy |
296 | 339 | def test21_return_array_scalar(): |
297 | 340 | collect() |
@@ -510,6 +553,7 @@ def test33_force_contig_numpy(): |
510 | 553 | assert b is not a |
511 | 554 | assert np.all(b == a) |
512 | 555 |
|
| 556 | + |
513 | 557 | @needs_torch |
514 | 558 | @pytest.mark.filterwarnings |
515 | 559 | def test34_force_contig_pytorch(): |
@@ -567,6 +611,7 @@ def test36_half(): |
567 | 611 | assert x.shape == (2, 4) |
568 | 612 | assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) |
569 | 613 |
|
| 614 | + |
570 | 615 | @needs_numpy |
571 | 616 | def test37_cast(): |
572 | 617 | a = t.cast(False) |
|
0 commit comments