Skip to content

Commit e489246

Browse files
authored
Merge pull request #65 from AdaptiveParticles/docs_v1
only allow cuda convolution with rescale_stencil for stencil shape (3, 3, 3)
2 parents 15431a9 + 238fdef commit e489246

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pyapr/filter/convolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __check_stencil(stencil: np.ndarray):
2222
return stencil.astype(np.float32)
2323

2424

25-
def __check_method(method: str, stencil: np.ndarray):
25+
def __check_method(method: str, stencil: np.ndarray, rescale_stencil: bool):
2626
if method == 'cuda':
2727
if not __cuda_build__:
2828
warn(f'Method \'cuda\' requires pyapr to be built with CUDA support (see installation instructions). '
@@ -33,6 +33,9 @@ def __check_method(method: str, stencil: np.ndarray):
3333
warn(f'Method \'cuda\' currently only supports stencils of shape (3, 3, 3) and (5, 5, 5), '
3434
f'but got {stencil.shape}. Using method \'pencil\' on CPU.', UserWarning)
3535
method = 'pencil'
36+
if rescale_stencil and stencil.shape == (5, 5, 5):
37+
warn(f'Method \'cuda\' with option \'rescale_stencil\' currently only supports stencils of shape (3, 3, 3). '
38+
f'Using method \'pencil\' on CPU.')
3639
return method
3740

3841

@@ -79,8 +82,8 @@ def correlate(apr: APR,
7982
method: str
8083
Method used to apply the operation:
8184
82-
- ``'pencil'``: construct isotropic neighborhoods of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
83-
- ``'slice'``: construct isotropic neighborhoods of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
85+
- ``'pencil'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
86+
- ``'slice'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
8487
- ``'cuda'``: compute the correlation using the GPU. Requires the package to be built with CUDA support,
8588
and ``stencil`` to have shape (3, 3, 3) or (5, 5, 5).
8689
@@ -94,7 +97,7 @@ def correlate(apr: APR,
9497

9598
_check_input(apr, parts, __allowed_input_types__)
9699
stencil = __check_stencil(stencil)
97-
method = __check_method(method, stencil)
100+
method = __check_method(method, stencil, rescale_stencil)
98101
output = output if isinstance(output, FloatParticles) else FloatParticles()
99102

100103
if method == 'pencil':
@@ -151,8 +154,8 @@ def convolve(apr: APR,
151154
method: str
152155
Method used to apply the operation:
153156
154-
- ``'pencil'``: construct isotropic neighborhoods of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
155-
- ``'slice'``: construct isotropic neighborhoods of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
157+
- ``'pencil'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
158+
- ``'slice'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
156159
- ``'cuda'``: compute the convolution using the GPU. Requires the package to be built with CUDA support,
157160
and ``stencil`` to have shape (3, 3, 3) or (5, 5, 5).
158161
@@ -166,7 +169,7 @@ def convolve(apr: APR,
166169

167170
_check_input(apr, parts, __allowed_input_types__)
168171
stencil = np.ascontiguousarray(np.flip(__check_stencil(stencil)))
169-
method = __check_method(method, stencil)
172+
method = __check_method(method, stencil, rescale_stencil)
170173
output = output if isinstance(output, FloatParticles) else FloatParticles()
171174

172175
if method == 'pencil':

0 commit comments

Comments
 (0)