Skip to content

Commit 08b4838

Browse files
upload missing file
1 parent 7e07126 commit 08b4838

File tree

2 files changed

+161
-26
lines changed

2 files changed

+161
-26
lines changed

.pre-commit-config.yaml

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ repos:
6565
- id: clang-format
6666
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
6767
# markdown, yaml, CSS, javascript
68-
- repo: https://github.com/pre-commit/mirrors-prettier
69-
rev: v4.0.0-alpha.8
70-
hooks:
71-
- id: prettier
72-
types_or: [markdown, yaml, css]
73-
# workflow files cannot be modified by pre-commit.ci
74-
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
68+
# - repo: https://github.com/pre-commit/mirrors-prettier
69+
# rev: v4.0.0-alpha.8
70+
# hooks:
71+
# - id: prettier
72+
# types_or: [markdown, yaml, css]
73+
# # workflow files cannot be modified by pre-commit.ci
74+
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
7575
# Shell
7676
- repo: https://github.com/scop/pre-commit-shfmt
7777
rev: v3.11.0-1
@@ -83,25 +83,25 @@ repos:
8383
hooks:
8484
- id: cmake-format
8585
#- id: cmake-lint
86-
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
87-
rev: v1.13.0
88-
hooks:
89-
- id: bibtex-tidy
90-
args:
91-
- --curly
92-
- --numeric
93-
- --align=13
94-
- --blank-lines
95-
# disable sort: the order of keys and fields has explict meanings
96-
#- --sort=key
97-
- --duplicates=key,doi,citation,abstract
98-
- --merge=combine
99-
#- --sort-fields
100-
#- --strip-comments
101-
- --trailing-commas
102-
- --encode-urls
103-
- --remove-empty-fields
104-
- --wrap=80
86+
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
87+
# rev: v1.13.0
88+
# hooks:
89+
# - id: bibtex-tidy
90+
# args:
91+
# - --curly
92+
# - --numeric
93+
# - --align=13
94+
# - --blank-lines
95+
# # disable sort: the order of keys and fields has explict meanings
96+
# #- --sort=key
97+
# - --duplicates=key,doi,citation,abstract
98+
# - --merge=combine
99+
# #- --sort-fields
100+
# #- --strip-comments
101+
# - --trailing-commas
102+
# - --encode-urls
103+
# - --remove-empty-fields
104+
# - --wrap=80
105105
# license header
106106
- repo: https://github.com/Lucas-C/pre-commit-hooks
107107
rev: v1.5.5

deepmd/pd/model/network/utils.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Optional,
4+
)
5+
6+
import paddle
7+
8+
9+
def aggregate(
10+
data: paddle.Tensor,
11+
owners: paddle.Tensor,
12+
average: bool = True,
13+
num_owner: Optional[int] = None,
14+
) -> paddle.Tensor:
15+
"""
16+
Aggregate rows in data by specifying the owners.
17+
18+
Parameters
19+
----------
20+
data : data tensor to aggregate [n_row, feature_dim]
21+
owners : specify the owner of each row [n_row, 1]
22+
average : if True, average the rows, if False, sum the rows.
23+
Default = True
24+
num_owner : the number of owners, this is needed if the
25+
max idx of owner is not presented in owners tensor
26+
Default = None
27+
28+
Returns
29+
-------
30+
output: [num_owner, feature_dim]
31+
"""
32+
bin_count = paddle.bincount(owners)
33+
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
34+
35+
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
36+
difference = num_owner - bin_count.shape[0]
37+
bin_count = paddle.concat([bin_count, paddle.ones_like(difference)])
38+
39+
# make sure this operation is done on the same device of data and owners
40+
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
41+
output = output.index_add_(owners, 0, data)
42+
if average:
43+
output = (output.T / bin_count).T
44+
return output
45+
46+
47+
def get_graph_index(
48+
nlist: paddle.Tensor,
49+
nlist_mask: paddle.Tensor,
50+
a_nlist_mask: paddle.Tensor,
51+
nall: int,
52+
):
53+
"""
54+
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
55+
56+
Parameters
57+
----------
58+
nlist : nf x nloc x nnei
59+
Neighbor list. (padded neis are set to 0)
60+
nlist_mask : nf x nloc x nnei
61+
Masks of the neighbor list. real nei 1 otherwise 0
62+
a_nlist_mask : nf x nloc x a_nnei
63+
Masks of the neighbor list for angle. real nei 1 otherwise 0
64+
nall
65+
The number of extended atoms.
66+
67+
Returns
68+
-------
69+
edge_index : n_edge x 2
70+
n2e_index : n_edge
71+
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
72+
n_ext2e_index : n_edge
73+
Broadcast indices from extended node(j) to edge(ij).
74+
angle_index : n_angle x 3
75+
n2a_index : n_angle
76+
Broadcast indices from extended node(j) to angle(ijk).
77+
eij2a_index : n_angle
78+
Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij).
79+
eik2a_index : n_angle
80+
Broadcast indices from extended edge(ik) to angle(ijk).
81+
"""
82+
nf, nloc, nnei = nlist.shape
83+
_, _, a_nnei = a_nlist_mask.shape
84+
# nf x nloc x nnei x nnei
85+
# nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :]
86+
a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
87+
n_edge = nlist_mask.sum().item()
88+
# n_angle = a_nlist_mask_3d.sum().item()
89+
90+
# following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index
91+
92+
# 1. atom graph
93+
# node(i) to edge(ij) index_select; edge(ij) to node aggregate
94+
nlist_loc_index = paddle.arange(0, nf * nloc, dtype=nlist.dtype).to(nlist.place)
95+
# nf x nloc x nnei
96+
n2e_index = nlist_loc_index.reshape([nf, nloc, 1]).expand([-1, -1, nnei])
97+
# n_edge
98+
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]
99+
100+
# node_ext(j) to edge(ij) index_select
101+
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall
102+
shifted_nlist = nlist + frame_shift[:, None, None]
103+
# n_edge
104+
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
105+
106+
# 2. edge graph
107+
# node(i) to angle(ijk) index_select
108+
n2a_index = nlist_loc_index.reshape([nf, nloc, 1, 1]).expand(
109+
[-1, -1, a_nnei, a_nnei]
110+
)
111+
# n_angle
112+
n2a_index = n2a_index[a_nlist_mask_3d]
113+
114+
# edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
115+
edge_id = paddle.arange(0, n_edge, dtype=nlist.dtype)
116+
# nf x nloc x nnei
117+
edge_index = paddle.zeros([nf, nloc, nnei], dtype=nlist.dtype)
118+
edge_index[nlist_mask] = edge_id
119+
# only cut a_nnei neighbors, to avoid nnei x nnei
120+
edge_index = edge_index[:, :, :a_nnei]
121+
edge_index_ij = edge_index.unsqueeze(-1).expand([-1, -1, -1, a_nnei])
122+
# n_angle
123+
eij2a_index = edge_index_ij[a_nlist_mask_3d]
124+
125+
# edge(ik) to angle(ijk) index_select
126+
edge_index_ik = edge_index.unsqueeze(-2).expand([-1, -1, a_nnei, -1])
127+
# n_angle
128+
eik2a_index = edge_index_ik[a_nlist_mask_3d]
129+
130+
return paddle.concat(
131+
[n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1
132+
), paddle.concat(
133+
[n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
134+
axis=-1,
135+
)

0 commit comments

Comments
 (0)