perf: gather node embedding before matmul #4744
Closed
Changes shown from 8 of 24 commits.

Commits:
ad23558  perf: gather node embedding before matmul (caic99)
b543cc8  swap dim (caic99)
09564f3  extract nei_node_ebd (caic99)
1888cd2  remove node_ebd_ext (caic99)
7569f78  not extend atype for training (caic99)
93e6851  remove debug statements (caic99)
cbaf7fd  remove duplicated statements (caic99)
29667c5  add param use_ext_ebd (caic99)
511c207  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8804e7f  add args impl for dpmodel (caic99)
df875cf  Merge branch 'fix-node-ebd-ext' of https://github.com/caic99/deepmd-k… (caic99)
7b7cdfa  fix dpmodel repflow serialization (caic99)
a746f82  fix UT (caic99)
5594c56  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
ac6677e  refactor node_ebd_ext ctor (caic99)
a349de1  feat(pt): add use_loc_mapping (iProzd)
dd0d642  solve merge conflict (caic99)
5e2a9e6  modify ut (caic99)
3ccefcc  fix name in dpmodel (caic99)
b6ca445  fix ut (caic99)
20fa93c  delete unused files (caic99)
d751e44  fix unused params (caic99)
7ba9580  format ut (caic99)
91611e2  fix mapping (caic99)
```diff
@@ -435,9 +435,8 @@ def optim_angle_update(
     def optim_edge_update(
         self,
         node_ebd: torch.Tensor,
-        node_ebd_ext: torch.Tensor,
+        nei_node_ebd: torch.Tensor,
         edge_ebd: torch.Tensor,
-        nlist: torch.Tensor,
         feat: str = "node",
     ) -> torch.Tensor:
         if feat == "node":
```
```diff
@@ -455,10 +454,8 @@ def optim_edge_update(
         # nf * nloc * node/edge_dim
         sub_node_update = torch.matmul(node_ebd, node)
-        # nf * nall * node/edge_dim
-        sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
-        # nf * nloc * nnei * node/edge_dim
-        sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
+        # nf * nloc * node/edge_dim
+        sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext)
         # nf * nloc * nnei * node/edge_dim
         sub_edge_update = torch.matmul(edge_ebd, edge)
```

Collaborator (on the `sub_node_ext_update` line): You missed a `_make_nei_g1` here. But when you use local mapping, it's more efficient to keep the old implementation.
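For context on this hunk: the change swaps the order of the per-neighbor gather and the linear transform. The sketch below illustrates the two orderings; it assumes `_make_nei_g1` behaves like an `nlist`-driven `torch.gather` along the atom axis (the real helper presumably also masks padded `-1` neighbor entries), and `gather_nei` plus all shapes are illustrative, not the project's API.

```python
import torch

def gather_nei(node_ebd_ext: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _make_nei_g1: pick each local atom's
    neighbor embeddings out of the extended (nf x nall x dim) table."""
    nf, nloc, nnei = nlist.shape
    dim = node_ebd_ext.shape[-1]
    index = nlist.reshape(nf, nloc * nnei, 1).expand(-1, -1, dim)
    return torch.gather(node_ebd_ext, 1, index).view(nf, nloc, nnei, dim)

nf, nloc, nall, nnei, dim = 2, 6, 10, 4, 8
node_ebd_ext = torch.randn(nf, nall, dim)
nlist = torch.randint(0, nall, (nf, nloc, nnei))
node_ext = torch.randn(dim, dim)  # stands for the learned weight the diff calls node_ext

# Old ordering: transform all nall atoms, then gather the neighbors.
old = gather_nei(torch.matmul(node_ebd_ext, node_ext), nlist)
# New ordering: gather the neighbors once, then transform only those rows.
new = torch.matmul(gather_nei(node_ebd_ext, nlist), node_ext)

assert torch.allclose(old, new, atol=1e-5)
```

Both orderings produce the same neighbor rows; which one is cheaper depends on whether the matmul runs over `nall` rows (old) or `nloc * nnei` gathered rows (new). That is presumably what the local-mapping remark above is about: with local mapping `nall` stays close to `nloc`, so the old ordering can be the cheaper one.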
```diff
@@ -469,7 +466,8 @@ def optim_edge_update(
     def forward(
         self,
-        node_ebd_ext: torch.Tensor,  # nf x nall x n_dim
+        node_ebd: torch.Tensor,  # nf x nloc x n_dim
+        node_ebd_ext: Optional[torch.Tensor],  # nf x nall x n_dim
         edge_ebd: torch.Tensor,  # nf x nloc x nnei x e_dim
         h2: torch.Tensor,  # nf x nloc x nnei x 3
         angle_ebd: torch.Tensor,  # nf x nloc x a_nnei x a_nnei x a_dim
```
```diff
@@ -514,8 +512,6 @@ def forward(
         Updated angle embedding.
         """
         nb, nloc, nnei, _ = edge_ebd.shape
-        nall = node_ebd_ext.shape[1]
-        node_ebd = node_ebd_ext[:, :nloc, :]
         assert (nb, nloc) == node_ebd.shape[:2]
         assert (nb, nloc, nnei) == h2.shape[:3]
         del a_nlist  # may be used in the future
```
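The two deleted lines relied on the layout of the extended table: the `nloc` local atoms come first, followed by ghost atoms up to `nall`, so the local block could be recovered by a slice. Under the new signature the caller passes `node_ebd` directly; the snippet below is only a sketch of that layout assumption, not code from the repository.

```python
# node_ebd_ext: (nf, nall, n_dim); the leading nloc rows are the local atoms,
# the remaining rows are ghost/padded atoms, which is what the deleted slice used.
node_ebd = node_ebd_ext[:, :nloc, :]  # (nf, nloc, n_dim), now prepared by the caller
```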
```diff
@@ -527,8 +523,10 @@ def forward(
         # node self mlp
         node_self_mlp = self.act(self.node_self_mlp(node_ebd))
         n_update_list.append(node_self_mlp)

-        nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
+        if node_ebd_ext is not None:
+            nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
+        else:
+            nei_node_ebd = _make_nei_g1(node_ebd, nlist)

         # node sym (grrg + drrd)
         node_sym_list: list[torch.Tensor] = []
```
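A hedged reading of the new branch, reusing the illustrative `gather_nei` helper from the earlier sketch: when an extended embedding is supplied, neighbors are gathered from it; when it is `None`, `nlist` is assumed to index local atoms directly, which is how I read the `use_loc_mapping` / `use_ext_ebd` additions in the commit list.

```python
# Sketch only; shapes follow the comments in the diff.
if node_ebd_ext is not None:
    # General case: nlist points into the extended (nall) table.
    nei_node_ebd = gather_nei(node_ebd_ext, nlist)  # nf x nloc x nnei x n_dim
else:
    # Local-mapping case: nlist already points at local (nloc) atoms.
    nei_node_ebd = gather_nei(node_ebd, nlist)      # nf x nloc x nnei x n_dim
```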
```diff
@@ -577,9 +575,8 @@ def forward(
             node_edge_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
-                    node_ebd_ext,
+                    nei_node_ebd,
                     edge_ebd,
-                    nlist,
                     "node",
                 )
             ) * sw.unsqueeze(-1)
```
```diff
@@ -605,9 +602,8 @@ def forward(
             edge_self_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
-                    node_ebd_ext,
+                    nei_node_ebd,
                     edge_ebd,
-                    nlist,
                     "edge",
                 )
             )
```
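Both call sites now take the `nei_node_ebd` that `forward` already gathers for the symmetrization step, instead of each running the matmul over all `nall` rows and gathering afterwards. A condensed, hypothetical sketch of the two calls under the new signature:

```python
from typing import Callable

import torch

def edge_updates(
    node_ebd: torch.Tensor,       # nf x nloc x n_dim
    nei_node_ebd: torch.Tensor,   # nf x nloc x nnei x n_dim
    edge_ebd: torch.Tensor,       # nf x nloc x nnei x e_dim
    sw: torch.Tensor,             # switch function, shape as in the source
    optim_edge_update: Callable[..., torch.Tensor],
    act: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Condensed sketch of the two call sites after the change; both reuse the
    pre-gathered nei_node_ebd instead of re-deriving it from node_ebd_ext."""
    node_edge_update = act(
        optim_edge_update(node_ebd, nei_node_ebd, edge_ebd, "node")
    ) * sw.unsqueeze(-1)
    edge_self_update = act(
        optim_edge_update(node_ebd, nei_node_ebd, edge_ebd, "edge")
    )
    return node_edge_update, edge_self_update
```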
Maybe you need to revert the modifications in this file, that is, keep exactly the modifications from commit iProzd@28803d9; with those I passed all the UTs.
Did you mean you'll submit another PR to this branch?