@@ -40,6 +40,55 @@ class CommonFFTOp:
4040 onnx_pb .TensorProto .COMPLEX128 ,
4141 ]
4242
43+ @classmethod
44+ def any_version_dft (cls , opset , ctx , node , is_rfft = True , ** kwargs ):
45+ """
46+ Implementation using ONNX DFT operator (opset 17+).
47+ Much simpler than the manual implementation.
48+ """
49+ # Get input tensor and optional fft_length
50+ input_tensor = node .input [0 ]
51+ fft_length_tensor = node .input [1 ] if len (node .input ) > 1 else None
52+
53+ # Get input properties
54+ input_dtype = ctx .get_dtype (input_tensor )
55+
56+ # Validate input
57+ utils .make_sure (input_dtype in CommonFFTOp .supported_dtypes ,
58+ "Unsupported input type for FFT/RFFT." )
59+
60+ # For real inputs, we need to add a dimension for complex representation
61+ # Input shape: [..., signal_length] -> [..., signal_length, 1]
62+ ones_const = ctx .make_const (name = utils .make_name ('ones_const' ),
63+ np_val = np .array ([1 ], dtype = np .int64 ))
64+ input_shape_node = ctx .make_node ('Shape' , [input_tensor ],
65+ name = utils .make_name ('input_shape' ))
66+ new_shape = ctx .make_node ('Concat' , [input_shape_node .output [0 ], ones_const .name ],
67+ attr = {'axis' : 0 }, name = utils .make_name ('new_shape' ))
68+ input_expanded = ctx .make_node ('Reshape' , [input_tensor , new_shape .output [0 ]],
69+ name = utils .make_name ('input_expanded' ))
70+
71+ # Perform DFT on the last dimension (signal dimension)
72+ # axis = -2 (signal dimension, -1 is reserved for complex representation)
73+ if fft_length_tensor is not None :
74+ dft_inputs = [input_expanded .output [0 ], fft_length_tensor ]
75+ else :
76+ dft_inputs = [input_expanded .output [0 ]]
77+
78+ dft_attrs = {'axis' : - 2 }
79+ if is_rfft :
80+ dft_attrs ['onesided' ] = 1
81+
82+ dft_result = ctx .make_node ('DFT' , dft_inputs ,
83+ attr = dft_attrs ,
84+ name = utils .make_name ('dft_result' ))
85+
86+ # Remove the original node and replace with DFT result
87+ ctx .remove_node (node .name )
88+ ctx .replace_all_inputs (node .output [0 ], dft_result .output [0 ])
89+
90+ return dft_result
91+
4392 @classmethod
4493 def any_version (cls , const_length , opset , ctx , node , axis = None ,
4594 fft_length = None , dim = None , onnx_dtype = None , shape = None ,
@@ -303,20 +352,24 @@ def DFT_real(x, fft_length=None):
303352 "MatMul" , inputs = [onx_real_imag_part_name , trx ],
304353 name = utils .make_name ('CPLX_M_' + node_name + 'rfft' ))
305354
306- if not const_length or (axis == 1 or (fft_length < shape_n and axis != 0 )):
307- size = fft_length // 2 + 1
308- if opset >= 10 :
309- cst_axis = ctx .make_const (
310- name = utils .make_name ('CPLX_csta' ), np_val = np .array ([- 2 ], dtype = np .int64 ))
311- cst_length = ctx .make_const (
312- name = utils .make_name ('CPLX_cstl' ), np_val = np .array ([size ], dtype = np .int64 ))
313- sliced_mult = ctx .make_node (
314- "Slice" , inputs = [mult .output [0 ], zero .name , cst_length .name , cst_axis .name ],
315- name = utils .make_name ('CPLX_S2_' + node_name + 'rfft' ))
355+ if not const_length or (axis == 1 or (fft_length is not None and fft_length < shape_n and axis != 0 )):
356+ if fft_length is not None :
357+ size = fft_length // 2 + 1
358+ if opset >= 10 :
359+ cst_axis = ctx .make_const (
360+ name = utils .make_name ('CPLX_csta' ), np_val = np .array ([- 2 ], dtype = np .int64 ))
361+ cst_length = ctx .make_const (
362+ name = utils .make_name ('CPLX_cstl' ), np_val = np .array ([size ], dtype = np .int64 ))
363+ sliced_mult = ctx .make_node (
364+ "Slice" , inputs = [mult .output [0 ], zero .name , cst_length .name , cst_axis .name ],
365+ name = utils .make_name ('CPLX_S2_' + node_name + 'rfft' ))
366+ else :
367+ sliced_mult = ctx .make_node (
368+ "Slice" , inputs = [mult .output [0 ]], attr = dict (starts = [0 ], ends = [size ], axes = [- 2 ]),
369+ name = utils .make_name ('CPLX_S2_' + node_name + 'rfft' ))
316370 else :
317- sliced_mult = ctx .make_node (
318- "Slice" , inputs = [mult .output [0 ]], attr = dict (starts = [0 ], ends = [size ], axes = [- 2 ]),
319- name = utils .make_name ('CPLX_S2_' + node_name + 'rfft' ))
371+ # Dynamic case - use shape inference to determine size
372+ sliced_mult = mult
320373 else :
321374 sliced_mult = mult
322375
@@ -358,6 +411,11 @@ def version_13(cls, ctx, node, **kwargs):
358411 # Unsqueeze changed in opset 13.
359412 return cls .any_version (True , 13 , ctx , node , ** kwargs )
360413
414+ @classmethod
415+ def version_17 (cls , ctx , node , ** kwargs ):
416+ # DFT operator available in opset 17+, use native ONNX DFT implementation
417+ return cls .any_version_dft (17 , ctx , node , is_rfft = True , ** kwargs )
418+
361419
362420@tf_op ("FFT" )
363421class FFTOp (CommonFFTOp ):
@@ -371,9 +429,90 @@ def version_1(cls, ctx, node, **kwargs):
371429 def version_13 (cls , ctx , node , ** kwargs ):
372430 return cls .any_version (False , 13 , ctx , node , ** kwargs )
373431
432+ @classmethod
433+ def version_17 (cls , ctx , node , ** kwargs ):
434+ # DFT operator available in opset 17+, use native ONNX DFT implementation
435+ return cls .any_version_dft (17 , ctx , node , is_rfft = False , ** kwargs )
436+
374437
375438class CommonFFT2DOp (CommonFFTOp ):
376439
440+ @classmethod
441+ def any_version_2d_dft (cls , opset , ctx , node , ** kwargs ):
442+ """
443+ Implementation using ONNX DFT operator (opset 17+).
444+ This is much simpler and more efficient than the manual implementation.
445+ """
446+ # Get input tensor and fft_length
447+ input_tensor = node .input [0 ]
448+ fft_length_tensor = node .input [1 ]
449+
450+ # Get input properties
451+ input_shape = ctx .get_shape (input_tensor )
452+ input_dtype = ctx .get_dtype (input_tensor )
453+
454+ # DFT-based implementation is more flexible and can work with various consumers
455+ # No need to restrict consumer types like the manual implementation
456+
457+ # Validate input
458+ utils .make_sure (input_dtype in CommonFFT2DOp .supported_dtypes ,
459+ "Unsupported input type for RFFT2D." )
460+
461+ # For RFFT2D, we need to perform DFT on the last two dimensions
462+ # First, we need to add a dimension for complex representation if input is real
463+ if input_shape is not None and len (input_shape ) > 0 :
464+ # Add dimension for real/imaginary parts if input is real
465+ # Input shape: [..., height, width] -> [..., height, width, 1]
466+ ones_const = ctx .make_const (name = utils .make_name ('ones_const' ),
467+ np_val = np .array ([1 ], dtype = np .int64 ))
468+ input_shape_node = ctx .make_node ('Shape' , [input_tensor ],
469+ name = utils .make_name ('input_shape' ))
470+ new_shape = ctx .make_node ('Concat' , [input_shape_node .output [0 ], ones_const .name ],
471+ attr = {'axis' : 0 }, name = utils .make_name ('new_shape' ))
472+ input_expanded = ctx .make_node ('Reshape' , [input_tensor , new_shape .output [0 ]],
473+ name = utils .make_name ('input_expanded' ))
474+ current_input = input_expanded .output [0 ]
475+ else :
476+ # Dynamic shape case
477+ input_shape_node = ctx .make_node ('Shape' , [input_tensor ],
478+ name = utils .make_name ('input_shape' ))
479+ ones_const = ctx .make_const (name = utils .make_name ('ones_const' ),
480+ np_val = np .array ([1 ], dtype = np .int64 ))
481+ new_shape = ctx .make_node ('Concat' , [input_shape_node .output [0 ], ones_const .name ],
482+ attr = {'axis' : 0 }, name = utils .make_name ('new_shape' ))
483+ input_expanded = ctx .make_node ('Reshape' , [input_tensor , new_shape .output [0 ]],
484+ name = utils .make_name ('input_expanded' ))
485+ current_input = input_expanded .output [0 ]
486+
487+ # Extract fft_length for each dimension (assuming fft_length has shape [2])
488+ zero_const = ctx .make_const (name = utils .make_name ('zero_const' ),
489+ np_val = np .array ([0 ], dtype = np .int64 ))
490+ one_const = ctx .make_const (name = utils .make_name ('one_const' ),
491+ np_val = np .array ([1 ], dtype = np .int64 ))
492+
493+ height_fft_length = ctx .make_node ('Gather' , [fft_length_tensor , zero_const .name ],
494+ attr = {'axis' : 0 }, name = utils .make_name ('height_fft_length' ))
495+ width_fft_length = ctx .make_node ('Gather' , [fft_length_tensor , one_const .name ],
496+ attr = {'axis' : 0 }, name = utils .make_name ('width_fft_length' ))
497+
498+ # Perform DFT on the second-to-last dimension (height)
499+ # axis = -2 (height dimension)
500+ dft_height = ctx .make_node ('DFT' , [current_input , height_fft_length .output [0 ]],
501+ attr = {'axis' : - 2 , 'onesided' : 1 },
502+ name = utils .make_name ('dft_height' ))
503+
504+ # Perform DFT on the last dimension (width)
505+ # axis = -2 (width dimension, which becomes -2 after the previous DFT)
506+ dft_result = ctx .make_node ('DFT' , [dft_height .output [0 ], width_fft_length .output [0 ]],
507+ attr = {'axis' : - 2 , 'onesided' : 1 },
508+ name = utils .make_name ('dft_width' ))
509+
510+ # Remove the original node and replace with DFT result
511+ ctx .remove_node (node .name )
512+ ctx .replace_all_inputs (node .output [0 ], dft_result .output [0 ])
513+
514+ return dft_result
515+
377516 @classmethod
378517 def any_version_2d (cls , const_length , opset , ctx , node , ** kwargs ):
379518 """
@@ -463,8 +602,8 @@ def onnx_rfft_2d_any_test(x, fft_length):
463602 consumers = ctx .find_output_consumers (node .output [0 ])
464603 consumer_types = set (op .type for op in consumers )
465604 utils .make_sure (
466- consumer_types == {'ComplexAbs' },
467- "Current implementation of RFFT2D only allows ComplexAbs as consumer not %r" ,
605+ consumer_types in [ {'ComplexAbs' }, { 'Squeeze' }] ,
606+ "Current implementation of RFFT2D only allows ComplexAbs or Squeeze as consumer not %r" ,
468607 consumer_types )
469608
470609 oldnode = node
@@ -905,6 +1044,11 @@ def version_13(cls, ctx, node, **kwargs):
9051044 # Unsqueeze changed in opset 13.
9061045 return cls .any_version_2d (True , 13 , ctx , node , ** kwargs )
9071046
1047+ @classmethod
1048+ def version_17 (cls , ctx , node , ** kwargs ):
1049+ # DFT operator available in opset 17+, use native ONNX DFT implementation
1050+ return cls .any_version_2d_dft (17 , ctx , node , ** kwargs )
1051+
9081052
9091053@tf_op ("ComplexAbs" )
9101054class ComplexAbsOp :
0 commit comments