@@ -71,7 +71,13 @@ def any_version_dft(cls, opset, ctx, node, is_rfft=True, **kwargs):
7171 # Perform DFT on the last dimension (signal dimension)
7272 # axis = -2 (signal dimension, -1 is reserved for complex representation)
7373 if fft_length_tensor is not None :
74- dft_inputs = [input_expanded .output [0 ], fft_length_tensor ]
74+ # Reshape fft_length to ensure it's a scalar (rank 0 tensor) as required by DFT
75+ empty_shape = ctx .make_const (name = utils .make_name ('empty_shape' ),
76+ np_val = np .array ([], dtype = np .int64 ))
77+ fft_length_scalar = ctx .make_node ('Reshape' ,
78+ [fft_length_tensor , empty_shape .name ],
79+ name = utils .make_name ('fft_length_scalar' ))
80+ dft_inputs = [input_expanded .output [0 ], fft_length_scalar .output [0 ]]
7581 else :
7682 dft_inputs = [input_expanded .output [0 ]]
7783
@@ -447,6 +453,12 @@ def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
447453 input_tensor = node .input [0 ]
448454 fft_length_tensor = node .input [1 ]
449455
456+ # Cast fft_length to int64 to ensure compatibility
457+ fft_length_tensor_int64 = ctx .make_node ('Cast' , [fft_length_tensor ],
458+ attr = {'to' : onnx_pb .TensorProto .INT64 },
459+ name = utils .make_name ('fft_length_cast' ))
460+ fft_length_tensor = fft_length_tensor_int64 .output [0 ]
461+
450462 # Get input properties
451463 input_shape = ctx .get_shape (input_tensor )
452464 input_dtype = ctx .get_dtype (input_tensor )
@@ -495,18 +507,50 @@ def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
495507 width_fft_length = ctx .make_node ('Gather' , [fft_length_tensor , one_const .name ],
496508 attr = {'axis' : 0 }, name = utils .make_name ('width_fft_length' ))
497509
510+ # Reshape fft_length values to ensure they're scalars (rank 0 tensors) as required by DFT
511+ empty_shape = ctx .make_const (name = utils .make_name ('empty_shape' ),
512+ np_val = np .array ([], dtype = np .int64 ))
513+ height_fft_length_scalar = ctx .make_node ('Reshape' ,
514+ [height_fft_length .output [0 ], empty_shape .name ],
515+ name = utils .make_name ('height_fft_length_scalar' ))
516+ width_fft_length_scalar = ctx .make_node ('Reshape' ,
517+ [width_fft_length .output [0 ], empty_shape .name ],
518+ name = utils .make_name ('width_fft_length_scalar' ))
519+
498520 # Perform DFT on the second-to-last dimension (height)
499- # axis = -2 (height dimension)
500- dft_height = ctx .make_node ('DFT' , [current_input , height_fft_length .output [0 ]],
501- attr = {'axis' : - 2 , 'onesided' : 1 },
521+ # Use onesided=0 for intermediate transforms
522+ # axis = -3 (height dimension)
523+ dft_height = ctx .make_node ('DFT' , [current_input , height_fft_length_scalar .output [0 ]],
524+ attr = {'axis' : - 3 , 'onesided' : 0 },
502525 name = utils .make_name ('dft_height' ))
503526
504527 # Perform DFT on the last dimension (width)
505- # axis = -2 (width dimension, which becomes -2 after the previous DFT)
506- dft_result = ctx .make_node ('DFT' , [dft_height .output [0 ], width_fft_length .output [0 ]],
528+ # Use onesided=1 for the second transform to get the one-sided FFT result
529+ # axis = -2 (width dimension)
530+ dft_width = ctx .make_node ('DFT' , [dft_height .output [0 ], width_fft_length_scalar .output [0 ]],
507531 attr = {'axis' : - 2 , 'onesided' : 1 },
508532 name = utils .make_name ('dft_width' ))
509533
534+ # The DFT output has shape [..., height, width//2+1, 2] where last dim is [real, imag]
535+ # We only need the real part (index 0) of the complex dimension
536+ zero_const_slice = ctx .make_const (name = utils .make_name ('zero_slice' ),
537+ np_val = np .array ([0 ], dtype = np .int64 ))
538+ one_const_end = ctx .make_const (name = utils .make_name ('one_const_end' ),
539+ np_val = np .array ([1 ], dtype = np .int64 ))
540+ minus_one_axis = ctx .make_const (name = utils .make_name ('minus_one_axis' ),
541+ np_val = np .array ([- 1 ], dtype = np .int64 ))
542+
543+ # Slice to get real part: dft_width[..., 0:1]
544+ dft_real = ctx .make_node ('Slice' ,
545+ [dft_width .output [0 ], zero_const_slice .name , one_const_end .name , minus_one_axis .name ],
546+ name = utils .make_name ('dft_real_part' ))
547+
548+ # Squeeze the last dimension to remove the complex dimension
549+ dft_result = GraphBuilder (ctx ).make_squeeze (
550+ {'data' : dft_real .output [0 ], 'axes' : [- 1 ]},
551+ name = utils .make_name ('dft_result_squeezed' ),
552+ return_node = True )
553+
510554 # Remove the original node and replace with DFT result
511555 ctx .remove_node (node .name )
512556 ctx .replace_all_inputs (node .output [0 ], dft_result .output [0 ])
0 commit comments