-
Notifications
You must be signed in to change notification settings - Fork 581
perf: gather node embedding before matmul #4744
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
ad23558
b543cc8
09564f3
1888cd2
7569f78
93e6851
cbaf7fd
29667c5
511c207
8804e7f
df875cf
7b7cdfa
a746f82
5594c56
ac6677e
a349de1
dd0d642
5e2a9e6
3ccefcc
b6ca445
20fa93c
d751e44
7ba9580
91611e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you need to abort modifications in this file, which means to keep exact modifications in commit iProzd@28803d9 and I passed all the uts.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Did you mean you'll submit another PR to this branch? |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -435,9 +435,8 @@ def optim_angle_update( | |
| def optim_edge_update( | ||
| self, | ||
| node_ebd: torch.Tensor, | ||
| node_ebd_ext: torch.Tensor, | ||
| nei_node_ebd: torch.Tensor, | ||
| edge_ebd: torch.Tensor, | ||
| nlist: torch.Tensor, | ||
| feat: str = "node", | ||
| ) -> torch.Tensor: | ||
| if feat == "node": | ||
|
|
@@ -455,10 +454,8 @@ def optim_edge_update( | |
|
|
||
| # nf * nloc * node/edge_dim | ||
| sub_node_update = torch.matmul(node_ebd, node) | ||
| # nf * nall * node/edge_dim | ||
| sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext) | ||
| # nf * nloc * nnei * node/edge_dim | ||
| sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist) | ||
| # nf * nloc * node/edge_dim | ||
| sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You missed a _make_nei_g1 here. But when you use locak mapping, it's more efficient to keep the old implementation. |
||
| # nf * nloc * nnei * node/edge_dim | ||
| sub_edge_update = torch.matmul(edge_ebd, edge) | ||
|
|
||
|
|
@@ -469,7 +466,8 @@ def optim_edge_update( | |
|
|
||
| def forward( | ||
| self, | ||
| node_ebd_ext: torch.Tensor, # nf x nall x n_dim | ||
| node_ebd: torch.Tensor, # nf x nloc x n_dim | ||
| node_ebd_ext: Optional[torch.Tensor], # nf x nall x n_dim | ||
| edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim | ||
| h2: torch.Tensor, # nf x nloc x nnei x 3 | ||
| angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim | ||
|
|
@@ -514,8 +512,6 @@ def forward( | |
| Updated angle embedding. | ||
| """ | ||
| nb, nloc, nnei, _ = edge_ebd.shape | ||
| nall = node_ebd_ext.shape[1] | ||
| node_ebd = node_ebd_ext[:, :nloc, :] | ||
| assert (nb, nloc) == node_ebd.shape[:2] | ||
| assert (nb, nloc, nnei) == h2.shape[:3] | ||
| del a_nlist # may be used in the future | ||
|
|
@@ -527,8 +523,10 @@ def forward( | |
| # node self mlp | ||
| node_self_mlp = self.act(self.node_self_mlp(node_ebd)) | ||
| n_update_list.append(node_self_mlp) | ||
|
|
||
| nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) | ||
| if node_ebd_ext is not None: | ||
caic99 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) | ||
| else: | ||
| nei_node_ebd = _make_nei_g1(node_ebd, nlist) | ||
|
|
||
| # node sym (grrg + drrd) | ||
| node_sym_list: list[torch.Tensor] = [] | ||
|
|
@@ -577,9 +575,8 @@ def forward( | |
| node_edge_update = self.act( | ||
| self.optim_edge_update( | ||
| node_ebd, | ||
| node_ebd_ext, | ||
| nei_node_ebd, | ||
| edge_ebd, | ||
| nlist, | ||
| "node", | ||
| ) | ||
| ) * sw.unsqueeze(-1) | ||
|
|
@@ -605,9 +602,8 @@ def forward( | |
| edge_self_update = self.act( | ||
| self.optim_edge_update( | ||
| node_ebd, | ||
| node_ebd_ext, | ||
| nei_node_ebd, | ||
| edge_ebd, | ||
| nlist, | ||
| "edge", | ||
| ) | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.