|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from typing import ( |
| 3 | + Optional, |
| 4 | +) |
| 5 | + |
| 6 | +import paddle |
| 7 | + |
| 8 | + |
| 9 | +def aggregate( |
| 10 | + data: paddle.Tensor, |
| 11 | + owners: paddle.Tensor, |
| 12 | + average: bool = True, |
| 13 | + num_owner: Optional[int] = None, |
| 14 | +) -> paddle.Tensor: |
| 15 | + """ |
| 16 | + Aggregate rows in data by specifying the owners. |
| 17 | +
|
| 18 | + Parameters |
| 19 | + ---------- |
| 20 | + data : data tensor to aggregate [n_row, feature_dim] |
| 21 | + owners : specify the owner of each row [n_row, 1] |
| 22 | + average : if True, average the rows, if False, sum the rows. |
| 23 | + Default = True |
| 24 | + num_owner : the number of owners, this is needed if the |
| 25 | + max idx of owner is not presented in owners tensor |
| 26 | + Default = None |
| 27 | +
|
| 28 | + Returns |
| 29 | + ------- |
| 30 | + output: [num_owner, feature_dim] |
| 31 | + """ |
| 32 | + bin_count = paddle.bincount(owners) |
| 33 | + bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count)) |
| 34 | + |
| 35 | + if (num_owner is not None) and (bin_count.shape[0] != num_owner): |
| 36 | + difference = num_owner - bin_count.shape[0] |
| 37 | + bin_count = paddle.concat([bin_count, paddle.ones_like(difference)]) |
| 38 | + |
| 39 | + # make sure this operation is done on the same device of data and owners |
| 40 | + output = paddle.zeros([bin_count.shape[0], data.shape[1]]) |
| 41 | + output = output.index_add_(owners, 0, data) |
| 42 | + if average: |
| 43 | + output = (output.T / bin_count).T |
| 44 | + return output |
| 45 | + |
| 46 | + |
| 47 | +def get_graph_index( |
| 48 | + nlist: paddle.Tensor, |
| 49 | + nlist_mask: paddle.Tensor, |
| 50 | + a_nlist_mask: paddle.Tensor, |
| 51 | + nall: int, |
| 52 | +): |
| 53 | + """ |
| 54 | + Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`. |
| 55 | +
|
| 56 | + Parameters |
| 57 | + ---------- |
| 58 | + nlist : nf x nloc x nnei |
| 59 | + Neighbor list. (padded neis are set to 0) |
| 60 | + nlist_mask : nf x nloc x nnei |
| 61 | + Masks of the neighbor list. real nei 1 otherwise 0 |
| 62 | + a_nlist_mask : nf x nloc x a_nnei |
| 63 | + Masks of the neighbor list for angle. real nei 1 otherwise 0 |
| 64 | + nall |
| 65 | + The number of extended atoms. |
| 66 | +
|
| 67 | + Returns |
| 68 | + ------- |
| 69 | + edge_index : n_edge x 2 |
| 70 | + n2e_index : n_edge |
| 71 | + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). |
| 72 | + n_ext2e_index : n_edge |
| 73 | + Broadcast indices from extended node(j) to edge(ij). |
| 74 | + angle_index : n_angle x 3 |
| 75 | + n2a_index : n_angle |
| 76 | + Broadcast indices from extended node(j) to angle(ijk). |
| 77 | + eij2a_index : n_angle |
| 78 | + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). |
| 79 | + eik2a_index : n_angle |
| 80 | + Broadcast indices from extended edge(ik) to angle(ijk). |
| 81 | + """ |
| 82 | + nf, nloc, nnei = nlist.shape |
| 83 | + _, _, a_nnei = a_nlist_mask.shape |
| 84 | + # nf x nloc x nnei x nnei |
| 85 | + # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :] |
| 86 | + a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] |
| 87 | + n_edge = nlist_mask.sum().item() |
| 88 | + # n_angle = a_nlist_mask_3d.sum().item() |
| 89 | + |
| 90 | + # following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index |
| 91 | + |
| 92 | + # 1. atom graph |
| 93 | + # node(i) to edge(ij) index_select; edge(ij) to node aggregate |
| 94 | + nlist_loc_index = paddle.arange(0, nf * nloc, dtype=nlist.dtype).to(nlist.place) |
| 95 | + # nf x nloc x nnei |
| 96 | + n2e_index = nlist_loc_index.reshape([nf, nloc, 1]).expand([-1, -1, nnei]) |
| 97 | + # n_edge |
| 98 | + n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0] |
| 99 | + |
| 100 | + # node_ext(j) to edge(ij) index_select |
| 101 | + frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall |
| 102 | + shifted_nlist = nlist + frame_shift[:, None, None] |
| 103 | + # n_edge |
| 104 | + n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1] |
| 105 | + |
| 106 | + # 2. edge graph |
| 107 | + # node(i) to angle(ijk) index_select |
| 108 | + n2a_index = nlist_loc_index.reshape([nf, nloc, 1, 1]).expand( |
| 109 | + [-1, -1, a_nnei, a_nnei] |
| 110 | + ) |
| 111 | + # n_angle |
| 112 | + n2a_index = n2a_index[a_nlist_mask_3d] |
| 113 | + |
| 114 | + # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate |
| 115 | + edge_id = paddle.arange(0, n_edge, dtype=nlist.dtype) |
| 116 | + # nf x nloc x nnei |
| 117 | + edge_index = paddle.zeros([nf, nloc, nnei], dtype=nlist.dtype) |
| 118 | + edge_index[nlist_mask] = edge_id |
| 119 | + # only cut a_nnei neighbors, to avoid nnei x nnei |
| 120 | + edge_index = edge_index[:, :, :a_nnei] |
| 121 | + edge_index_ij = edge_index.unsqueeze(-1).expand([-1, -1, -1, a_nnei]) |
| 122 | + # n_angle |
| 123 | + eij2a_index = edge_index_ij[a_nlist_mask_3d] |
| 124 | + |
| 125 | + # edge(ik) to angle(ijk) index_select |
| 126 | + edge_index_ik = edge_index.unsqueeze(-2).expand([-1, -1, a_nnei, -1]) |
| 127 | + # n_angle |
| 128 | + eik2a_index = edge_index_ik[a_nlist_mask_3d] |
| 129 | + |
| 130 | + return paddle.concat( |
| 131 | + [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1 |
| 132 | + ), paddle.concat( |
| 133 | + [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)], |
| 134 | + axis=-1, |
| 135 | + ) |
0 commit comments