Skip to content

Commit 8a9fc78

Browse files
authored
perf: skip bincount if unnecessary (#4773)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved the aggregation logic to conditionally compute bin counts, enhancing accuracy when averaging or specifying the number of owners. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 674ebad commit 8a9fc78

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

deepmd/pt/model/network/utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,22 @@ def aggregate(
3030
-------
3131
output: [num_owner, feature_dim]
3232
"""
33-
bin_count = torch.bincount(owners)
34-
bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
35-
36-
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
37-
difference = num_owner - bin_count.shape[0]
38-
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
39-
40-
# make sure this operation is done on the same device of data and owners
41-
output = data.new_zeros([bin_count.shape[0], data.shape[1]])
33+
if num_owner is None or average:
34+
# requires bincount
35+
bin_count = torch.bincount(owners)
36+
bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
37+
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
38+
difference = num_owner - bin_count.shape[0]
39+
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
40+
else:
41+
num_owner = bin_count.shape[0]
42+
else:
43+
bin_count = None
44+
45+
output = data.new_zeros([num_owner, data.shape[1]])
4246
output = output.index_add_(0, owners, data)
4347
if average:
48+
assert bin_count is not None
4449
output = (output.T / bin_count).T
4550
return output
4651

0 commit comments

Comments
 (0)