3737
3838"""
3939
40+ import math
41+
4042import dpctl .tensor as dpt
4143import dpctl .utils as dpu
4244import numpy
5557from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
5658from .dpnp_utils .dpnp_utils_statistics import dpnp_cov , dpnp_median
5759
60+ min_ = min # pylint: disable=used-before-assignment
61+
5862__all__ = [
5963 "amax" ,
6064 "amin" ,
@@ -457,16 +461,55 @@ def _get_padding(a_size, v_size, mode):
457461 return l_pad , r_pad
458462
459463
460- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
def _choose_conv_method(a, v, rdtype):
    """
    Select the algorithm name used to combine `a` and `v`.

    Parameters
    ----------
    a, v : array-like objects exposing ``size``
        Inputs of the operation; `a` must not be shorter than `v`.
        Presumably 1-D dpnp arrays -- confirm against the caller.
    rdtype : dtype
        Result data type the decision is based on.

    Returns
    -------
    str
        ``"direct"`` or ``"fft"``.

    Raises
    ------
    ValueError
        If `rdtype` is neither boolean nor a numeric dtype.

    """

    assert a.size >= v.size

    size_threshold = 10**4

    # Boolean results, or any input shorter than the threshold, always go
    # through the direct path.
    if (
        rdtype == dpnp.bool
        or v.size < size_threshold
        or a.size < size_threshold
    ):
        return "direct"

    if dpnp.issubdtype(rdtype, dpnp.integer):
        # Upper bound on the magnitude of any output element:
        # max(|a|) * sum(|v|), evaluated with Python integers.
        peak_a = int(dpnp.max(dpnp.abs(a)))
        total_v = int(dpnp.sum(dpnp.abs(v)))
        bound = int(peak_a * total_v)

        # Compare the bound against a limit derived from the mantissa width
        # of the device default floating type.
        float_type = dpnp.default_float_type(a.sycl_device)
        mantissa_limit = 2 ** numpy.finfo(float_type).nmant - 1
        if bound > mantissa_limit:
            return "direct"

    if not dpnp.issubdtype(rdtype, dpnp.number):
        raise ValueError(f"Unsupported dtype: {rdtype} ")

    return "fft"
485+
486+
487+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
461488 queue = a .sycl_queue
489+ device = a .sycl_device
490+
491+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
492+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
462493
463- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
464- out_size = l_pad + r_pad + a .size - v .size + 1
494+ if supported_dtype is None :
495+ raise ValueError (
496+ f"function does not support input types "
497+ f"({ a .dtype .name } , { v .dtype .name } ), "
498+ "and the inputs could not be coerced to any "
499+ f"supported types. List of supported types: "
500+ f"{ [st .name for st in supported_types ]} "
501+ )
502+
503+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
504+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
505+
506+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
507+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
465508 # out type is the same as input type
466509 out = dpnp .empty_like (a , shape = out_size , usm_type = usm_type )
467510
468- a_usm = dpnp .get_usm_ndarray (a )
469- v_usm = dpnp .get_usm_ndarray (v )
511+ a_usm = dpnp .get_usm_ndarray (a_casted )
512+ v_usm = dpnp .get_usm_ndarray (v_casted )
470513 out_usm = dpnp .get_usm_ndarray (out )
471514
472515 _manager = dpu .SequentialOrderManager [queue ]
@@ -484,7 +527,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
484527 return out
485528
486529
487- def correlate (a , v , mode = "valid" ):
def _convolve_fft(a, v, l_pad, r_pad, rtype):
    """
    Convolve `a` with `v` through forward/inverse FFTs and return a slice.

    Both inputs are transformed at a common power-of-two length, the
    spectra are multiplied, and the inverse transform is sliced according
    to `l_pad` and `r_pad`.

    Parameters
    ----------
    a, v : arrays
        Inputs; `a` must not be shorter than `v`.
    l_pad, r_pad : int
        Left and right padding amounts; `l_pad` must be smaller than
        ``v.size``.
    rtype : dtype
        Result data type. Floating types keep only the real part of the
        inverse transform; integer and boolean types keep the rounded real
        part; any other type gets the inverse transform unchanged.

    Returns
    -------
    array
        Elements ``[v.size - 1 - l_pad : a.size + r_pad]`` of the
        post-processed inverse transform.

    """

    assert a.size >= v.size
    assert l_pad < v.size

    # +1 is needed to avoid circular convolution
    linear_len = a.size + r_pad + 1
    n_fft = 2 ** math.ceil(math.log2(linear_len))

    first = v.size - 1 - l_pad
    last = linear_len - 1

    a_spectrum = dpnp.fft.fft(a, n_fft)  # pylint: disable=no-member
    v_spectrum = dpnp.fft.fft(v, n_fft)  # pylint: disable=no-member
    product = a_spectrum * v_spectrum
    result = dpnp.fft.ifft(product)  # pylint: disable=no-member

    if dpnp.issubdtype(rtype, dpnp.floating):
        return result.real[first:last]

    if dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
        # round the whole array first, then slice
        return result.real.round()[first:last]

    return result[first:last]
551+
552+
553+ def correlate (a , v , mode = "valid" , method = "auto" ):
488554 r"""
489555 Cross-correlation of two 1-dimensional sequences.
490556
@@ -509,6 +575,20 @@ def correlate(a, v, mode="valid"):
509575 is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
510576
511577 Default: ``"valid"``.
578+ method : {'auto', 'direct', 'fft'}, optional
579+ `'direct'`: The correlation is determined directly from sums.
580+
581+ `'fft'`: The Fourier Transform is used to perform the calculations.
582+ This method is faster for long sequences but can have accuracy issues.
583+
584+ `'auto'`: Automatically chooses direct or Fourier method based on
585+ an estimate of which is faster.
586+
587+ Note: Use of the FFT convolution on input containing NAN or INF
588+ will lead to the entire output being NAN or INF.
589+ Use method='direct' when your input contains NAN or INF values.
590+
591+ Default: ``'auto'``.
512592
513593 Notes
514594 -----
@@ -576,20 +656,14 @@ def correlate(a, v, mode="valid"):
576656 f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
577657 )
578658
579- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
659+ supported_methods = ["auto" , "direct" , "fft" ]
660+ if method not in supported_methods :
661+ raise ValueError (
662+ f"Unknown method: { method } . Supported methods: { supported_methods } "
663+ )
580664
581665 device = a .sycl_device
582666 rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
583- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
584-
585- if supported_dtype is None :
586- raise ValueError (
587- f"function does not support input types "
588- f"({ a .dtype .name } , { v .dtype .name } ), "
589- "and the inputs could not be coerced to any "
590- f"supported types. List of supported types: "
591- f"{ [st .name for st in supported_types ]} "
592- )
593667
594668 if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
595669 v = dpnp .conj (v )
@@ -601,10 +675,15 @@ def correlate(a, v, mode="valid"):
601675
602676 l_pad , r_pad = _get_padding (a .size , v .size , mode )
603677
604- a_casted = dpnp . asarray ( a , dtype = supported_dtype , order = "C" )
605- v_casted = dpnp . asarray ( v , dtype = supported_dtype , order = "C" )
678+ if method == "auto" :
679+ method = _choose_conv_method ( a , v , rdtype )
606680
607- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
681+ if method == "direct" :
682+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
683+ elif method == "fft" :
684+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
685+ else :
686+ raise ValueError (f"Unknown method: { method } " )
608687
609688 if revert :
610689 r = r [::- 1 ]
0 commit comments