Skip to content

Commit 344e9ff

Browse files
committed
Bird
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent f99b289 commit 344e9ff

File tree

1 file changed

+50
-6
lines changed

1 file changed

+50
-6
lines changed

tf2onnx/onnx_opset/signal.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ def any_version_dft(cls, opset, ctx, node, is_rfft=True, **kwargs):
7171
# Perform DFT on the last dimension (signal dimension)
7272
# axis = -2 (signal dimension, -1 is reserved for complex representation)
7373
if fft_length_tensor is not None:
74-
dft_inputs = [input_expanded.output[0], fft_length_tensor]
74+
# Reshape fft_length to ensure it's a scalar (rank 0 tensor) as required by DFT
75+
empty_shape = ctx.make_const(name=utils.make_name('empty_shape'),
76+
np_val=np.array([], dtype=np.int64))
77+
fft_length_scalar = ctx.make_node('Reshape',
78+
[fft_length_tensor, empty_shape.name],
79+
name=utils.make_name('fft_length_scalar'))
80+
dft_inputs = [input_expanded.output[0], fft_length_scalar.output[0]]
7581
else:
7682
dft_inputs = [input_expanded.output[0]]
7783

@@ -447,6 +453,12 @@ def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
447453
input_tensor = node.input[0]
448454
fft_length_tensor = node.input[1]
449455

456+
# Cast fft_length to int64 to ensure compatibility
457+
fft_length_tensor_int64 = ctx.make_node('Cast', [fft_length_tensor],
458+
attr={'to': onnx_pb.TensorProto.INT64},
459+
name=utils.make_name('fft_length_cast'))
460+
fft_length_tensor = fft_length_tensor_int64.output[0]
461+
450462
# Get input properties
451463
input_shape = ctx.get_shape(input_tensor)
452464
input_dtype = ctx.get_dtype(input_tensor)
@@ -495,18 +507,50 @@ def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
495507
width_fft_length = ctx.make_node('Gather', [fft_length_tensor, one_const.name],
496508
attr={'axis': 0}, name=utils.make_name('width_fft_length'))
497509

510+
# Reshape fft_length values to ensure they're scalars (rank 0 tensors) as required by DFT
511+
empty_shape = ctx.make_const(name=utils.make_name('empty_shape'),
512+
np_val=np.array([], dtype=np.int64))
513+
height_fft_length_scalar = ctx.make_node('Reshape',
514+
[height_fft_length.output[0], empty_shape.name],
515+
name=utils.make_name('height_fft_length_scalar'))
516+
width_fft_length_scalar = ctx.make_node('Reshape',
517+
[width_fft_length.output[0], empty_shape.name],
518+
name=utils.make_name('width_fft_length_scalar'))
519+
498520
# Perform DFT on the second-to-last dimension (height)
499-
# axis = -2 (height dimension)
500-
dft_height = ctx.make_node('DFT', [current_input, height_fft_length.output[0]],
501-
attr={'axis': -2, 'onesided': 1},
521+
# Use onesided=0 for intermediate transforms
522+
# axis = -3 (height dimension)
523+
dft_height = ctx.make_node('DFT', [current_input, height_fft_length_scalar.output[0]],
524+
attr={'axis': -3, 'onesided': 0},
502525
name=utils.make_name('dft_height'))
503526

504527
# Perform DFT on the last dimension (width)
505-
# axis = -2 (width dimension, which becomes -2 after the previous DFT)
506-
dft_result = ctx.make_node('DFT', [dft_height.output[0], width_fft_length.output[0]],
528+
# Use onesided=1 for the second transform to get the one-sided FFT result
529+
# axis = -2 (width dimension)
530+
dft_width = ctx.make_node('DFT', [dft_height.output[0], width_fft_length_scalar.output[0]],
507531
attr={'axis': -2, 'onesided': 1},
508532
name=utils.make_name('dft_width'))
509533

534+
# The DFT output has shape [..., height, width//2+1, 2] where last dim is [real, imag]
535+
# We only need the real part (index 0) of the complex dimension
536+
zero_const_slice = ctx.make_const(name=utils.make_name('zero_slice'),
537+
np_val=np.array([0], dtype=np.int64))
538+
one_const_end = ctx.make_const(name=utils.make_name('one_const_end'),
539+
np_val=np.array([1], dtype=np.int64))
540+
minus_one_axis = ctx.make_const(name=utils.make_name('minus_one_axis'),
541+
np_val=np.array([-1], dtype=np.int64))
542+
543+
# Slice to get real part: dft_width[..., 0:1]
544+
dft_real = ctx.make_node('Slice',
545+
[dft_width.output[0], zero_const_slice.name, one_const_end.name, minus_one_axis.name],
546+
name=utils.make_name('dft_real_part'))
547+
548+
# Squeeze the last dimension to remove the complex dimension
549+
dft_result = GraphBuilder(ctx).make_squeeze(
550+
{'data': dft_real.output[0], 'axes': [-1]},
551+
name=utils.make_name('dft_result_squeezed'),
552+
return_node=True)
553+
510554
# Remove the original node and replace with DFT result
511555
ctx.remove_node(node.name)
512556
ctx.replace_all_inputs(node.output[0], dft_result.output[0])

0 commit comments

Comments
 (0)