Skip to content

Commit f99b289

Browse files
committed
Implement rfft with DFT-17
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 4fed7de commit f99b289

File tree

1 file changed

+159
-15
lines changed

1 file changed

+159
-15
lines changed

tf2onnx/onnx_opset/signal.py

Lines changed: 159 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,55 @@ class CommonFFTOp:
4040
onnx_pb.TensorProto.COMPLEX128,
4141
]
4242

43+
@classmethod
44+
def any_version_dft(cls, opset, ctx, node, is_rfft=True, **kwargs):
45+
"""
46+
Implementation using ONNX DFT operator (opset 17+).
47+
Much simpler than the manual implementation.
48+
"""
49+
# Get input tensor and optional fft_length
50+
input_tensor = node.input[0]
51+
fft_length_tensor = node.input[1] if len(node.input) > 1 else None
52+
53+
# Get input properties
54+
input_dtype = ctx.get_dtype(input_tensor)
55+
56+
# Validate input
57+
utils.make_sure(input_dtype in CommonFFTOp.supported_dtypes,
58+
"Unsupported input type for FFT/RFFT.")
59+
60+
# For real inputs, we need to add a dimension for complex representation
61+
# Input shape: [..., signal_length] -> [..., signal_length, 1]
62+
ones_const = ctx.make_const(name=utils.make_name('ones_const'),
63+
np_val=np.array([1], dtype=np.int64))
64+
input_shape_node = ctx.make_node('Shape', [input_tensor],
65+
name=utils.make_name('input_shape'))
66+
new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
67+
attr={'axis': 0}, name=utils.make_name('new_shape'))
68+
input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
69+
name=utils.make_name('input_expanded'))
70+
71+
# Perform DFT on the last dimension (signal dimension)
72+
# axis = -2 (signal dimension, -1 is reserved for complex representation)
73+
if fft_length_tensor is not None:
74+
dft_inputs = [input_expanded.output[0], fft_length_tensor]
75+
else:
76+
dft_inputs = [input_expanded.output[0]]
77+
78+
dft_attrs = {'axis': -2}
79+
if is_rfft:
80+
dft_attrs['onesided'] = 1
81+
82+
dft_result = ctx.make_node('DFT', dft_inputs,
83+
attr=dft_attrs,
84+
name=utils.make_name('dft_result'))
85+
86+
# Remove the original node and replace with DFT result
87+
ctx.remove_node(node.name)
88+
ctx.replace_all_inputs(node.output[0], dft_result.output[0])
89+
90+
return dft_result
91+
4392
@classmethod
4493
def any_version(cls, const_length, opset, ctx, node, axis=None,
4594
fft_length=None, dim=None, onnx_dtype=None, shape=None,
@@ -303,20 +352,24 @@ def DFT_real(x, fft_length=None):
303352
"MatMul", inputs=[onx_real_imag_part_name, trx],
304353
name=utils.make_name('CPLX_M_' + node_name + 'rfft'))
305354

306-
if not const_length or (axis == 1 or (fft_length < shape_n and axis != 0)):
307-
size = fft_length // 2 + 1
308-
if opset >= 10:
309-
cst_axis = ctx.make_const(
310-
name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
311-
cst_length = ctx.make_const(
312-
name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
313-
sliced_mult = ctx.make_node(
314-
"Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
315-
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
355+
if not const_length or (axis == 1 or (fft_length is not None and fft_length < shape_n and axis != 0)):
356+
if fft_length is not None:
357+
size = fft_length // 2 + 1
358+
if opset >= 10:
359+
cst_axis = ctx.make_const(
360+
name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
361+
cst_length = ctx.make_const(
362+
name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
363+
sliced_mult = ctx.make_node(
364+
"Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
365+
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
366+
else:
367+
sliced_mult = ctx.make_node(
368+
"Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
369+
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
316370
else:
317-
sliced_mult = ctx.make_node(
318-
"Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
319-
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
371+
# Dynamic case - use shape inference to determine size
372+
sliced_mult = mult
320373
else:
321374
sliced_mult = mult
322375

@@ -358,6 +411,11 @@ def version_13(cls, ctx, node, **kwargs):
358411
# Unsqueeze changed in opset 13.
359412
return cls.any_version(True, 13, ctx, node, **kwargs)
360413

414+
@classmethod
415+
def version_17(cls, ctx, node, **kwargs):
416+
# DFT operator available in opset 17+, use native ONNX DFT implementation
417+
return cls.any_version_dft(17, ctx, node, is_rfft=True, **kwargs)
418+
361419

362420
@tf_op("FFT")
363421
class FFTOp(CommonFFTOp):
@@ -371,9 +429,90 @@ def version_1(cls, ctx, node, **kwargs):
371429
def version_13(cls, ctx, node, **kwargs):
372430
return cls.any_version(False, 13, ctx, node, **kwargs)
373431

432+
@classmethod
433+
def version_17(cls, ctx, node, **kwargs):
434+
# DFT operator available in opset 17+, use native ONNX DFT implementation
435+
return cls.any_version_dft(17, ctx, node, is_rfft=False, **kwargs)
436+
374437

375438
class CommonFFT2DOp(CommonFFTOp):
376439

440+
@classmethod
441+
def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
442+
"""
443+
Implementation using ONNX DFT operator (opset 17+).
444+
This is much simpler and more efficient than the manual implementation.
445+
"""
446+
# Get input tensor and fft_length
447+
input_tensor = node.input[0]
448+
fft_length_tensor = node.input[1]
449+
450+
# Get input properties
451+
input_shape = ctx.get_shape(input_tensor)
452+
input_dtype = ctx.get_dtype(input_tensor)
453+
454+
# DFT-based implementation is more flexible and can work with various consumers
455+
# No need to restrict consumer types like the manual implementation
456+
457+
# Validate input
458+
utils.make_sure(input_dtype in CommonFFT2DOp.supported_dtypes,
459+
"Unsupported input type for RFFT2D.")
460+
461+
# For RFFT2D, we need to perform DFT on the last two dimensions
462+
# First, we need to add a dimension for complex representation if input is real
463+
if input_shape is not None and len(input_shape) > 0:
464+
# Add dimension for real/imaginary parts if input is real
465+
# Input shape: [..., height, width] -> [..., height, width, 1]
466+
ones_const = ctx.make_const(name=utils.make_name('ones_const'),
467+
np_val=np.array([1], dtype=np.int64))
468+
input_shape_node = ctx.make_node('Shape', [input_tensor],
469+
name=utils.make_name('input_shape'))
470+
new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
471+
attr={'axis': 0}, name=utils.make_name('new_shape'))
472+
input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
473+
name=utils.make_name('input_expanded'))
474+
current_input = input_expanded.output[0]
475+
else:
476+
# Dynamic shape case
477+
input_shape_node = ctx.make_node('Shape', [input_tensor],
478+
name=utils.make_name('input_shape'))
479+
ones_const = ctx.make_const(name=utils.make_name('ones_const'),
480+
np_val=np.array([1], dtype=np.int64))
481+
new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
482+
attr={'axis': 0}, name=utils.make_name('new_shape'))
483+
input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
484+
name=utils.make_name('input_expanded'))
485+
current_input = input_expanded.output[0]
486+
487+
# Extract fft_length for each dimension (assuming fft_length has shape [2])
488+
zero_const = ctx.make_const(name=utils.make_name('zero_const'),
489+
np_val=np.array([0], dtype=np.int64))
490+
one_const = ctx.make_const(name=utils.make_name('one_const'),
491+
np_val=np.array([1], dtype=np.int64))
492+
493+
height_fft_length = ctx.make_node('Gather', [fft_length_tensor, zero_const.name],
494+
attr={'axis': 0}, name=utils.make_name('height_fft_length'))
495+
width_fft_length = ctx.make_node('Gather', [fft_length_tensor, one_const.name],
496+
attr={'axis': 0}, name=utils.make_name('width_fft_length'))
497+
498+
# Perform DFT on the second-to-last dimension (height)
499+
# axis = -2 (height dimension)
500+
dft_height = ctx.make_node('DFT', [current_input, height_fft_length.output[0]],
501+
attr={'axis': -2, 'onesided': 1},
502+
name=utils.make_name('dft_height'))
503+
504+
# Perform DFT on the last dimension (width)
505+
# axis = -2 (width dimension, which becomes -2 after the previous DFT)
506+
dft_result = ctx.make_node('DFT', [dft_height.output[0], width_fft_length.output[0]],
507+
attr={'axis': -2, 'onesided': 1},
508+
name=utils.make_name('dft_width'))
509+
510+
# Remove the original node and replace with DFT result
511+
ctx.remove_node(node.name)
512+
ctx.replace_all_inputs(node.output[0], dft_result.output[0])
513+
514+
return dft_result
515+
377516
@classmethod
378517
def any_version_2d(cls, const_length, opset, ctx, node, **kwargs):
379518
"""
@@ -463,8 +602,8 @@ def onnx_rfft_2d_any_test(x, fft_length):
463602
consumers = ctx.find_output_consumers(node.output[0])
464603
consumer_types = set(op.type for op in consumers)
465604
utils.make_sure(
466-
consumer_types == {'ComplexAbs'},
467-
"Current implementation of RFFT2D only allows ComplexAbs as consumer not %r",
605+
consumer_types in [{'ComplexAbs'}, {'Squeeze'}],
606+
"Current implementation of RFFT2D only allows ComplexAbs or Squeeze as consumer not %r",
468607
consumer_types)
469608

470609
oldnode = node
@@ -905,6 +1044,11 @@ def version_13(cls, ctx, node, **kwargs):
9051044
# Unsqueeze changed in opset 13.
9061045
return cls.any_version_2d(True, 13, ctx, node, **kwargs)
9071046

1047+
@classmethod
1048+
def version_17(cls, ctx, node, **kwargs):
1049+
# DFT operator available in opset 17+, use native ONNX DFT implementation
1050+
return cls.any_version_2d_dft(17, ctx, node, **kwargs)
1051+
9081052

9091053
@tf_op("ComplexAbs")
9101054
class ComplexAbsOp:

0 commit comments

Comments
 (0)