-
Notifications
You must be signed in to change notification settings - Fork 582
Chore: refactor dpmodel model #4296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
anyangml
merged 21 commits into
deepmodeling:devel
from
anyangml:chore/refactor-dpmodel-model
Nov 19, 2024
Merged
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
88f5079
chore: refactor dpmodel atomic model
anyangml de96159
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ab3bd29
feat: expose dos/polar/dipole model in dpmodel
anyangml 8e85e59
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c5dab77
chore: refactor atomic model fitting assertion
anyangml 58b987e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a285d63
Merge branch 'devel' into chore/refactor-dpmodel-model
anyangml e2e9d60
chore: fix typo
anyangml 47114ec
feat: try add dos consistent UT
anyangml 48d8d6c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 81bd024
fix: UT para
anyangml 7490529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c040df3
fix: UT res extract
anyangml 8940dac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4802a9a
fix: add dos consistency UT
anyangml 3355f50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5a11ff8
fix: expose more dpmodel interface
anyangml 9071e73
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1dc7fd9
chore: vectorize for loop
anyangml 0afee21
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 33452bd
Merge branch 'devel' into chore/refactor-dpmodel-model
anyangml File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import numpy as np | ||
|
|
||
| from deepmd.dpmodel.fitting.dipole_fitting import ( | ||
| DipoleFitting, | ||
| ) | ||
|
|
||
| from .dp_atomic_model import ( | ||
| DPAtomicModel, | ||
| ) | ||
|
|
||
|
|
||
| class DPDipoleAtomicModel(DPAtomicModel): | ||
| def __init__(self, descriptor, fitting, type_map, **kwargs): | ||
| if not isinstance(fitting, DipoleFitting): | ||
| raise TypeError( | ||
| "fitting must be an instance of DipoleFitting for DPDipoleAtomicModel" | ||
| ) | ||
| super().__init__(descriptor, fitting, type_map, **kwargs) | ||
|
|
||
| def apply_out_stat( | ||
| self, | ||
| ret: dict[str, np.ndarray], | ||
| atype: np.ndarray, | ||
| ): | ||
| # dipole not applying bias | ||
| return ret |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from deepmd.dpmodel.fitting.dos_fitting import ( | ||
| DOSFittingNet, | ||
| ) | ||
|
|
||
| from .dp_atomic_model import ( | ||
| DPAtomicModel, | ||
| ) | ||
|
|
||
|
|
||
| class DPDOSAtomicModel(DPAtomicModel): | ||
| def __init__(self, descriptor, fitting, type_map, **kwargs): | ||
| if not isinstance(fitting, DOSFittingNet): | ||
| raise TypeError( | ||
| "fitting must be an instance of DOSFittingNet for DPDOSAtomicModel" | ||
| ) | ||
| super().__init__(descriptor, fitting, type_map, **kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from deepmd.dpmodel.fitting.ener_fitting import ( | ||
| EnergyFittingNet, | ||
| InvarFitting, | ||
| ) | ||
|
|
||
| from .dp_atomic_model import ( | ||
| DPAtomicModel, | ||
| ) | ||
|
|
||
|
|
||
| class DPEnergyAtomicModel(DPAtomicModel): | ||
| def __init__(self, descriptor, fitting, type_map, **kwargs): | ||
| if not ( | ||
| isinstance(fitting, EnergyFittingNet) or isinstance(fitting, InvarFitting) | ||
| ): | ||
| raise TypeError( | ||
| "fitting must be an instance of EnergyFittingNet or InvarFitting for DPEnergyAtomicModel" | ||
| ) | ||
| super().__init__(descriptor, fitting, type_map, **kwargs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
| import numpy as np | ||
|
|
||
| from deepmd.dpmodel.fitting.polarizability_fitting import ( | ||
| PolarFitting, | ||
| ) | ||
|
|
||
| from .dp_atomic_model import ( | ||
| DPAtomicModel, | ||
| ) | ||
|
|
||
|
|
||
| class DPPolarAtomicModel(DPAtomicModel): | ||
| def __init__(self, descriptor, fitting, type_map, **kwargs): | ||
| if not isinstance(fitting, PolarFitting): | ||
| raise TypeError( | ||
| "fitting must be an instance of PolarFitting for DPPolarAtomicModel" | ||
| ) | ||
| super().__init__(descriptor, fitting, type_map, **kwargs) | ||
|
|
||
| def apply_out_stat( | ||
| self, | ||
| ret: dict[str, np.ndarray], | ||
| atype: np.ndarray, | ||
| ): | ||
| """Apply the stat to each atomic output. | ||
| Parameters | ||
| ---------- | ||
| ret | ||
| The returned dict by the forward_atomic method | ||
| atype | ||
| The atom types. nf x nloc | ||
| """ | ||
| out_bias, out_std = self._fetch_out_stat(self.bias_keys) | ||
|
|
||
| if self.fitting_net.shift_diag: | ||
| nframes, nloc = atype.shape | ||
| dtype = out_bias[self.bias_keys[0]].dtype | ||
| for kk in self.bias_keys: | ||
| ntypes = out_bias[kk].shape[0] | ||
| temp = np.zeros(ntypes, dtype=dtype) | ||
|
||
| for i in range(ntypes): | ||
| temp[i] = np.mean(np.diagonal(out_bias[kk][i].reshape(3, 3))) | ||
anyangml marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
anyangml marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| modified_bias = temp[atype] | ||
|
|
||
| # (nframes, nloc, 1) | ||
| modified_bias = ( | ||
| modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype]) | ||
| ) | ||
|
|
||
| eye = np.eye(3, dtype=dtype) | ||
| eye = np.tile(eye, (nframes, nloc, 1, 1)) | ||
| # (nframes, nloc, 3, 3) | ||
| modified_bias = modified_bias[..., np.newaxis] * eye | ||
|
|
||
| # nf x nloc x odims, out_bias: ntypes x odims | ||
| ret[kk] = ret[kk] + modified_bias | ||
| return ret | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
|
|
||
| from deepmd.dpmodel.atomic_model import ( | ||
| DPDipoleAtomicModel, | ||
| ) | ||
| from deepmd.dpmodel.model.model import ( | ||
| BaseModel, | ||
| ) | ||
|
|
||
| from .dp_model import ( | ||
| DPModelCommon, | ||
| ) | ||
| from .make_model import ( | ||
| make_model, | ||
| ) | ||
|
|
||
| DPDipoleModel_ = make_model(DPDipoleAtomicModel) | ||
|
|
||
|
|
||
| @BaseModel.register("dipole") | ||
| class DipoleModel(DPModelCommon, DPDipoleModel_): | ||
| model_type = "dipole" | ||
anyangml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__( | ||
| self, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| DPModelCommon.__init__(self) | ||
| DPDipoleModel_.__init__(self, *args, **kwargs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
| from deepmd.dpmodel.atomic_model import ( | ||
| DPDOSAtomicModel, | ||
| ) | ||
| from deepmd.dpmodel.model.model import ( | ||
| BaseModel, | ||
| ) | ||
|
|
||
| from .dp_model import ( | ||
| DPModelCommon, | ||
| ) | ||
| from .make_model import ( | ||
| make_model, | ||
| ) | ||
|
|
||
| DPDOSModel_ = make_model(DPDOSAtomicModel) | ||
|
|
||
|
|
||
| @BaseModel.register("dos") | ||
| class DOSModel(DPModelCommon, DPDOSModel_): | ||
| model_type = "dos" | ||
|
|
||
| def __init__( | ||
| self, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| DPModelCommon.__init__(self) | ||
| DPDOSModel_.__init__(self, *args, **kwargs) | ||
anyangml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
| from deepmd.dpmodel.atomic_model import ( | ||
| DPPolarAtomicModel, | ||
| ) | ||
| from deepmd.dpmodel.model.model import ( | ||
| BaseModel, | ||
| ) | ||
|
|
||
| from .dp_model import ( | ||
| DPModelCommon, | ||
| ) | ||
| from .make_model import ( | ||
| make_model, | ||
| ) | ||
|
|
||
| DPPolarModel_ = make_model(DPPolarAtomicModel) | ||
|
|
||
|
|
||
| @BaseModel.register("polar") | ||
| class PolarModel(DPModelCommon, DPPolarModel_): | ||
| model_type = "polar" | ||
anyangml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__( | ||
| self, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| DPModelCommon.__init__(self) | ||
| DPPolarModel_.__init__(self, *args, **kwargs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.