From 34e0a21a7c6c059f251f7e1c247dc5615a99d9ea Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 13 Oct 2025 11:31:06 +0200 Subject: [PATCH 1/3] Use output_size argument in torch.repeat_interleave when index is provided --- torch_geometric/utils/_softmax.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/utils/_softmax.py b/torch_geometric/utils/_softmax.py index c6f19f8e4a0d..508bfcac2b23 100644 --- a/torch_geometric/utils/_softmax.py +++ b/torch_geometric/utils/_softmax.py @@ -65,11 +65,12 @@ def softmax( size = ([1] * dim) + [-1] count = ptr[1:] - ptr[:-1] ptr = ptr.view(size) + output_size = index.shape[dim] if index is not None else None src_max = segment(src.detach(), ptr, reduce='max') - src_max = src_max.repeat_interleave(count, dim=dim) + src_max = src_max.repeat_interleave(count, dim=dim, output_size=output_size) out = (src - src_max).exp() out_sum = segment(out, ptr, reduce='sum') + 1e-16 - out_sum = out_sum.repeat_interleave(count, dim=dim) + out_sum = out_sum.repeat_interleave(count, dim=dim, output_size=output_size) elif index is not None: N = maybe_num_nodes(index, num_nodes) src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max') From bb9e9be2850883dd2df298a261ba1880564ae705 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 13 Oct 2025 11:31:14 +0200 Subject: [PATCH 2/3] Extend test --- test/utils/test_softmax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/utils/test_softmax.py b/test/utils/test_softmax.py index 7102d3203f07..a7033e78c9f5 100644 --- a/test/utils/test_softmax.py +++ b/test/utils/test_softmax.py @@ -17,11 +17,13 @@ def test_softmax(): out = softmax(src, index) assert out.tolist() == [0.5, 0.5, 1, 1] assert softmax(src, ptr=ptr).tolist() == out.tolist() + assert softmax(src, index=index, ptr=ptr).tolist() == out.tolist() src = src.view(-1, 1) out = softmax(src, index) assert out.tolist() == [[0.5], [0.5], [1], [1]] assert softmax(src, ptr=ptr).tolist() == out.tolist() + assert softmax(src, index=index, ptr=ptr).tolist() == out.tolist() jit = torch.jit.script(softmax) assert torch.allclose(jit(src, index), out) From 88d903994a474eae4505eb51abe7df7add8eb950 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Oct 2025 09:47:15 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/utils/_softmax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_geometric/utils/_softmax.py b/torch_geometric/utils/_softmax.py index 508bfcac2b23..4de7c8432a5a 100644 --- a/torch_geometric/utils/_softmax.py +++ b/torch_geometric/utils/_softmax.py @@ -67,10 +67,12 @@ def softmax( ptr = ptr.view(size) output_size = index.shape[dim] if index is not None else None src_max = segment(src.detach(), ptr, reduce='max') - src_max = src_max.repeat_interleave(count, dim=dim, output_size=output_size) + src_max = src_max.repeat_interleave(count, dim=dim, + output_size=output_size) out = (src - src_max).exp() out_sum = segment(out, ptr, reduce='sum') + 1e-16 - out_sum = out_sum.repeat_interleave(count, dim=dim, output_size=output_size) + out_sum = out_sum.repeat_interleave(count, dim=dim, + output_size=output_size) elif index is not None: N = maybe_num_nodes(index, num_nodes) src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')