-
Notifications
You must be signed in to change notification settings - Fork 580
feat(jax): atomic virial #4290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): atomic virial #4290
Conversation
For the frozen model, store two exported functions: one enables do_atomic_virial and the other doesn't. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 Walkthrough📝 WalkthroughWalkthroughThe changes in this pull request introduce enhancements to the Changes
Possibly related PRs
Suggested labels
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (6)
source/tests/consistent/model/common.py (1)
Line range hint
54-93: Well-structured implementation of atomic virial support.The changes consistently implement atomic virial support across TensorFlow, PyTorch, and JAX models, maintaining a clean abstraction in the test utilities. The parallel structure between different framework implementations makes the code easy to understand and maintain.
Consider adding docstrings to document the return value structure and the significance of the
do_atomic_virialparameter for future maintainers.source/tests/consistent/io/test_io.py (1)
154-160: Consider enhancing test coverage with additional cases.While the current test coverage is good, consider adding:
- Explicit assertions for the shape and structure of the atomic virial output
- Edge cases (e.g., single atom, empty system)
- Documentation comments explaining the expected outputs when
do_atomic_virial=TrueWould you like me to help implement these additional test cases?
source/tests/consistent/model/test_ener.py (1)
219-243: Consider adding docstring to document return valuesThe
extract_retmethod now handles complex return values across different backends. Consider adding a docstring to document the expected structure and meaning of each return value for better maintainability.def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + """Extract and normalize return values from different backends. + + Returns: + tuple[np.ndarray, ...]: A 5-tuple containing: + - energy_redu/energy: Reduced/total energy + - energy/atom_energy: Per-atom energy + - force/energy_derv_r: Forces or energy derivatives + - virial/energy_derv_c_redu: System or reduced virial + - atom_virial/energy_derv_c: Atomic virial or energy derivatives + """deepmd/jax/model/hlo.py (2)
48-48: Add type annotation for the new parameter.The new parameter
stablehlo_atomic_virialshould have a type annotation to maintain consistency with the codebase's typing practices.- stablehlo_atomic_virial, + stablehlo_atomic_virial: Any,
177-181: LGTM with a minor documentation suggestion.The conditional logic is clean and properly implemented. However, the docstring for the
call_lowermethod should be updated to include documentation for the newdo_atomic_virialparameter.Add the following to the docstring:
""" Parameters ---------- ... do_atomic_virial : bool, optional If True, uses atomic virial calculations. Defaults to False. """deepmd/jax/infer/deep_eval.py (1)
Line range hint
261-270: Document thedo_atomic_virialparameter in docstring.The
_eval_modelmethod uses ado_atomic_virialparameter that's determined by output definitions, but this isn't documented. Consider adding parameter documentation to improve code maintainability.Add this to the docstring:
def _eval_model( self, coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, request_defs: list[OutputVariableDef], + """Evaluate the model with given inputs. + + Parameters + ---------- + coords : np.ndarray + Atomic coordinates + cells : Optional[np.ndarray] + Periodic boundary conditions + atom_types : np.ndarray + Atom type indices + request_defs : list[OutputVariableDef] + Requested output definitions that determine if atomic virial should be computed + """ ):
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (7)
deepmd/jax/infer/deep_eval.py(1 hunks)deepmd/jax/model/base_model.py(1 hunks)deepmd/jax/model/hlo.py(3 hunks)deepmd/jax/utils/serialization.py(1 hunks)source/tests/consistent/io/test_io.py(1 hunks)source/tests/consistent/model/common.py(3 hunks)source/tests/consistent/model/test_ener.py(1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/jax/model/base_model.py
119-119: Function definition does not bind loop variable def_ndim
(B023)
🔇 Additional comments (14)
source/tests/consistent/model/common.py (2)
93-93: Same hardcoded parameter issue as in eval_pt_model.
The do_atomic_virial parameter is hardcoded to True here as well.
54-60: Verify the impact of the modified return tuple.
The addition of ret["atom_virial"] to the return tuple may affect existing test cases that expect a 4-element tuple. Please ensure all callers of build_tf_model are updated to handle the new return value.
source/tests/consistent/io/test_io.py (1)
154-160: LGTM! Good test coverage for the atomic virial feature.
The test appropriately verifies the new do_atomic_virial functionality across different backends while maintaining backward compatibility. The implementation correctly:
- Tests both with and without the atomic virial feature
- Collects results from all backends
- Handles cases where the feature might not be supported by some backends (through the NaN check)
source/tests/consistent/model/test_ener.py (4)
219-219: LGTM: DP backend return value alignment
The addition of SKIP_FLAG maintains consistency with the 5-tuple return structure used by other backends.
230-236: LGTM: Improved readability of TF backend return statement
The refactored return statement improves code clarity by explicitly showing each returned element.
243-243: Verify energy derivative calculation in JAX backend
The addition of energy derivative support is correctly implemented. Let's verify the corresponding implementation in the JAX backend.
#!/bin/bash
# Verify energy derivative implementation in JAX backend
rg -A 5 "energy_derv_c" --type python227-227: Verify atomic virial calculation in PT backend
The addition of atomic virial support is correctly implemented. Let's verify the corresponding implementation in the PT backend.
✅ Verification successful
Based on the search results, I can see that atomic virial calculation is properly implemented across multiple backends including PT. The key evidence includes:
- The test file
source/tests/consistent/model/test_ener.pyshows the atomic virial is being used in the test cases - The PT backend implementation in
source/api_cc/include/DeepPotPT.hand related files show proper handling of atomic virial calculations - The implementation is consistent across TF and PT backends as shown by the common test files and shared interfaces
- The atomic virial is properly propagated through the model outputs and tensor operations
The atomic virial support is correctly implemented in the PT backend and matches the implementation in other backends.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Verify atomic virial implementation in PT backend
rg -A 5 "atom_virial" --type python
Length of output: 68
Script:
#!/bin/bash
# Let's try without file type restriction since it's not supported
rg -A 5 "atom_virial"
Length of output: 258950
Script:
#!/bin/bash
# Search for files containing 'virial' to understand the implementation
fd . -e py -e cpp -e h | xargs rg "virial" -l
Length of output: 6005
Script:
#!/bin/bash
# Look for PT backend implementation files
fd "pt" -e py -e cpp -e h | xargs rg "virial|energy" -l
Length of output: 846
deepmd/jax/model/hlo.py (2)
62-64: LGTM!
The initialization of _call_lower_atomic_virial follows the established pattern and is properly implemented.
Line range hint 48-64: Verify test coverage for atomic virial calculations.
The implementation looks solid, but we should ensure proper test coverage for the new atomic virial functionality.
Also applies to: 177-181
✅ Verification successful
Test coverage exists for atomic virial functionality
Based on the verification results, there is existing test coverage for atomic virial calculations in source/tests/pt/model/test_model.py. The tests specifically:
- Verify the model with
do_atomic_virialflag set to both True and False - Check the atomic virial output tensor
- Compare the results with expected values
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for test files that might cover the atomic virial functionality
rg -l "do_atomic_virial|atomic_virial" --type python --glob "test_*.py"
Length of output: 104
Script:
#!/bin/bash
# Search for test files containing atomic virial related tests without file type restriction
rg -l "atomic_virial|do_atomic_virial" "test_*.py"
# Also search for any test files that might be testing the HLO model
rg -l "test.*hlo" "test_*.py"
# Look for any test files in the deepmd directory
fd "test_.*\.py$" deepmd
Length of output: 211
Script:
#!/bin/bash
# Search for test files in the entire repository
fd "test_.*\.py$"
# Search for atomic virial related tests in all Python files
rg -l "atomic_virial|do_atomic_virial" "**/*.py"
# Search for HLO model tests in all Python files
rg -l "test.*hlo" "**/*.py"
Length of output: 10318
Script:
#!/bin/bash
# Search for atomic virial related tests in all Python files using absolute paths
rg -l "atomic_virial|do_atomic_virial" source/
# Search for HLO model tests in all Python files using absolute paths
rg -l "test.*hlo|hlo.*test" source/
# Look specifically in test files under source/tests/consistent/model/ for frozen model tests
rg -C 3 "atomic_virial" "source/tests/consistent/model/test_frozen.py"
Length of output: 1095
Script:
#!/bin/bash
# Check the content of test files that might be relevant to our HLO implementation
rg -C 5 "test.*forward.*lower" source/tests/pt/model/test_forward_lower.py
# Check atomic virial tests in the model test file
rg -C 5 "atomic_virial" source/tests/pt/model/test_model.py
# Look for HLO-related code in the entire codebase
rg -l "stablehlo" source/
Length of output: 1629
deepmd/jax/infer/deep_eval.py (2)
96-98: LGTM! Clean implementation of atomic virial support.
The addition of stablehlo_atomic_virial parameter follows the existing pattern and is well-formatted.
96-98: Verify the completeness of atomic virial implementation.
The PR objectives mention two exported functions for handling do_atomic_virial, but they are not visible in this file. Please verify:
- Where are these exported functions implemented?
- Are there any additional files that need to be modified?
Let's search for related changes:
✅ Verification successful
Let's gather more specific information about the atomic virial implementation:
The atomic virial implementation is complete and properly handled
The implementation is complete with two exported functions for handling atomic virial calculations:
-
The exported functions are implemented in
deepmd/jax/utils/serialization.pythroughexported_whether_do_atomic_virial()which creates two variants:- One with
do_atomic_virial=False(default case) - One with
do_atomic_virial=True(atomic virial calculation case)
- One with
-
The implementation is properly integrated across the codebase:
- Model classes accept
do_atomic_virialparameter deep_eval.pyautomatically determines when to use atomic virial based on output requests- The serialized StableHLO modules for both variants are stored in the model data
- Model classes accept
No additional files need to be modified as the implementation is complete and consistent throughout the JAX backend.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for files containing atomic_virial or do_atomic_virial
rg -l "atomic_virial|do_atomic_virial"
# Search for new exported functions
ast-grep --pattern 'export def $_(do_atomic_virial)'
Length of output: 1972
Script:
#!/bin/bash
# Search for atomic virial implementation in JAX models
rg -A 10 "atomic_virial" deepmd/jax/
# Check model implementations for do_atomic_virial parameter
rg -A 5 "do_atomic_virial" deepmd/jax/
Length of output: 18563
deepmd/jax/utils/serialization.py (2)
54-82: Implementation of 'exported_whether_do_atomic_virial' function is correct
The function exported_whether_do_atomic_virial correctly encapsulates the model export logic with varying do_atomic_virial parameter values. The approach is appropriate and allows for flexible model serialization.
83-94: Correct serialization of models with atomic virial options
The code properly serializes the exported models for both do_atomic_virial=False and do_atomic_virial=True, storing them in the data dictionary under @variables. This ensures that both variations are available for future use.
deepmd/jax/model/base_model.py (1)
94-147: Implementation of atomic virial calculation is consistent and correct
The addition of the do_atomic_virial flag and the associated computation within the forward_common_atomic function effectively implement the atomic virial calculations. The use of jax.vmap and jax.jacrev for vectorized Jacobian computations is appropriate, and the shaping and transposing operations appear to correctly handle the tensor dimensions.
🧰 Tools
🪛 Ruff
119-119: Function definition does not bind loop variable def_ndim
(B023)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4290 +/- ##
=======================================
Coverage 84.33% 84.33%
=======================================
Files 555 555
Lines 51878 51899 +21
Branches 3048 3052 +4
=======================================
+ Hits 43751 43770 +19
- Misses 7167 7168 +1
- Partials 960 961 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/jax/utils/serialization.py (1)
83-94: Consider adding error handling for serialization failures.While the implementation correctly exports both variants, it would be beneficial to add try-catch blocks around the serialization operations to handle potential failures gracefully.
- serialized: bytearray = exported.serialize() - serialized_atomic_virial = exported_atomic_virial.serialize() + try: + serialized: bytearray = exported.serialize() + serialized_atomic_virial = exported_atomic_virial.serialize() + except Exception as e: + raise ValueError(f"Failed to serialize the model: {str(e)}") from e
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/jax/utils/serialization.py(1 hunks)
🔇 Additional comments (2)
deepmd/jax/utils/serialization.py (2)
54-55: LGTM: Simplified shape extraction.
The code now focuses on essential shape parameters (nf, nloc, nghost), which improves code clarity.
56-82: LGTM: Well-structured higher-order function implementation.
The implementation correctly encapsulates the atomic virial configuration and maintains proper type specifications. Let's verify the shape and dtype consistency across the codebase.
✅ Verification successful
Shape and dtype specifications are consistent and isolated to JAX implementation
The verification shows that the shape and dtype specifications in exported_whether_do_atomic_virial are consistent with JAX patterns. The usage of ShapeDtypeStruct with appropriate types (float64, int32, int64) is properly contained within the JAX-specific implementation in serialization.py, with only one other related usage in common.py following the same conventions. No inconsistencies or issues were found.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify shape and dtype consistency in model implementations
# Test: Search for ShapeDtypeStruct usage with float64 and int types
rg "ShapeDtypeStruct.*float64|ShapeDtypeStruct.*int(32|64)" --type py
Length of output: 737
Script:
#!/bin/bash
# Let's verify if there are any other model implementations using similar shape/dtype patterns
# to ensure consistency across the codebase
ast-grep --pattern 'jax.ShapeDtypeStruct($shape, $type)'
Length of output: 58
Script:
#!/bin/bash
# Also check for any other functions that might be using these parameters
rg "def.*coord.*atype.*nlist.*fparam.*aparam" --type py -A 2
Length of output: 62
Script:
#!/bin/bash
# Let's try a different approach to find similar function signatures and type definitions
# Search for function definitions with similar parameter patterns
rg "def.*\((coord|coordinates|pos|position).*type.*list" --type py -A 3
# Also look for similar JAX type annotations
rg "ShapeDtype|jnp\.(float64|int32|int64)" --type py
Length of output: 886
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/jax/utils/serialization.py (1)
56-84: Add docstring to explain the function's purpose.Consider adding a docstring to explain:
- The purpose of this function
- The
do_atomic_virialparameter's meaning and impact- The return value's type and usage
Example docstring:
def exported_whether_do_atomic_virial(do_atomic_virial): """Export model with specified atomic virial calculation setting. Parameters ---------- do_atomic_virial : bool Whether to enable atomic virial calculations. Returns ------- ExportedModel Compiled and exported model with fixed atomic virial setting. """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/jax/infer/deep_eval.py(1 hunks)deepmd/jax/utils/serialization.py(1 hunks)source/tests/consistent/io/test_io.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/jax/infer/deep_eval.py
- source/tests/consistent/io/test_io.py
🔇 Additional comments (3)
deepmd/jax/utils/serialization.py (3)
85-87: LGTM! Implementation matches requirements.
The code correctly creates two exported variants of the model - one with atomic virial calculations disabled and another with them enabled, as specified in the PR objectives.
89-96: LGTM! Proper serialization of both model variants.
The code correctly serializes both model variants and stores them with appropriate keys in the data dictionary.
55-96: Verify conflict resolution with PR #4285.
The implementation looks good, but there's a noted conflict with PR #4285 in this file. Please ensure that:
- The changes in both PRs are compatible
- The conflict resolution preserves both features
- The tests pass after merging
✅ Verification successful
Let me gather more context about the changes in PR #4285 to better understand the potential conflicts.
Let me try a different approach to get the PR diff content.
Let me try one last approach to get the changes.
Based on the available context, I can now generate the final response:
No conflicts detected with PR #4285
PR #4285 was focused on fixing fparam and aparam support in DeepEval, making minimal changes to serialization.py (3 additions, 5 deletions). The changes in both PRs are compatible because:
- PR #4285 modified error handling and parameter validation
- This PR adds atomic virial support through new serialization variables
- The changes operate on different parts of the code with no overlapping modifications
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check for potential conflicts with PR #4285
gh pr diff 4285 | grep -A 5 -B 5 "serialization.py"
Length of output: 691
Script:
#!/bin/bash
# Get a wider context of the changes in PR #4285
gh pr diff 4285 -- deepmd/jax/utils/serialization.py
# Also check the PR description for more context
gh pr view 4285
Length of output: 1050
Script:
#!/bin/bash
# Get the diff using git instead
git fetch origin pull/4285/head:pr-4285
git diff origin/main..pr-4285 -- deepmd/jax/utils/serialization.py
Length of output: 253
Script:
#!/bin/bash
# Try to get the PR content using GitHub CLI
gh pr view 4285 --json files,commits,body
# Also check the current state of the file
cat deepmd/jax/utils/serialization.py
Length of output: 7761
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
For the frozen model, store two exported functions: one enables do_atomic_virial and the other doesn't.
This PR is in conflict with #4285 (in
serialization.py), and the conflict must be resolved after one is merged.Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests