@@ -22,7 +22,7 @@ def __check_stencil(stencil: np.ndarray):
2222 return stencil .astype (np .float32 )
2323
2424
25- def __check_method (method : str , stencil : np .ndarray ):
25+ def __check_method (method : str , stencil : np .ndarray , rescale_stencil : bool ):
2626 if method == 'cuda' :
2727 if not __cuda_build__ :
2828 warn (f'Method \' cuda\' requires pyapr to be built with CUDA support (see installation instructions). '
@@ -33,6 +33,9 @@ def __check_method(method: str, stencil: np.ndarray):
3333 warn (f'Method \' cuda\' currently only supports stencils of shape (3, 3, 3) and (5, 5, 5), '
3434 f'but got { stencil .shape } . Using method \' pencil\' on CPU.' , UserWarning )
3535 method = 'pencil'
36+ if rescale_stencil and stencil .shape == (5 , 5 , 5 ):
37+ warn (f'Method \' cuda\' with option \' rescale_stencil\' currently only supports stencils of shape (3, 3, 3). '
38+ f'Using method \' pencil\' on CPU.' )
3639 return method
3740
3841
@@ -79,8 +82,8 @@ def correlate(apr: APR,
7982 method: str
8083 Method used to apply the operation:
8184
82- - ``'pencil'``: construct isotropic neighborhoods of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
83- - ``'slice'``: construct isotropic neighborhoods of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
85+ - ``'pencil'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
86+ - ``'slice'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
8487 - ``'cuda'``: compute the correlation using the GPU. Requires the package to be built with CUDA support,
8588 and ``stencil`` to have shape (3, 3, 3) or (5, 5, 5).
8689
@@ -94,7 +97,7 @@ def correlate(apr: APR,
9497
9598 _check_input (apr , parts , __allowed_input_types__ )
9699 stencil = __check_stencil (stencil )
97- method = __check_method (method , stencil )
100+ method = __check_method (method , stencil , rescale_stencil )
98101 output = output if isinstance (output , FloatParticles ) else FloatParticles ()
99102
100103 if method == 'pencil' :
@@ -151,8 +154,8 @@ def convolve(apr: APR,
151154 method: str
152155 Method used to apply the operation:
153156
154- - ``'pencil'``: construct isotropic neighborhoods of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
155- - ``'slice'``: construct isotropic neighborhoods of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
157+ - ``'pencil'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], stencil.shape[1], apr.shape[2])
158+ - ``'slice'``: construct isotropic neighborhoods in a buffer of shape (stencil.shape[0], apr.shape[1], apr.shape[2])
156159 - ``'cuda'``: compute the convolution using the GPU. Requires the package to be built with CUDA support,
157160 and ``stencil`` to have shape (3, 3, 3) or (5, 5, 5).
158161
@@ -166,7 +169,7 @@ def convolve(apr: APR,
166169
167170 _check_input (apr , parts , __allowed_input_types__ )
168171 stencil = np .ascontiguousarray (np .flip (__check_stencil (stencil )))
169- method = __check_method (method , stencil )
172+ method = __check_method (method , stencil , rescale_stencil )
170173 output = output if isinstance (output , FloatParticles ) else FloatParticles ()
171174
172175 if method == 'pencil' :
0 commit comments