fix(jax): fix compatibility with flax 0.12 #5067
for more information, see https://pre-commit.ci
…odel" This reverts commit 4932fbe.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #5067      +/-   ##
==========================================
- Coverage   84.33%   84.28%   -0.06%
==========================================
  Files         709      709
  Lines       70435    70547     +112
  Branches     3618     3619       +1
==========================================
+ Hits        59402    59458      +56
- Misses       9867     9922      +55
- Partials     1166     1167       +1
📝 Walkthrough
This PR adds version-conditional wrapping of JAX attributes with Flax NNX data structures across descriptor, atomic model, fitting, and utility modules. When the Flax version is >= 0.12.0, specific attributes are wrapped with `nnx.data()` or `nnx.List()`.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25–30 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- deepmd/jax/atomic_model/base_atomic_model.py (2 hunks)
- deepmd/jax/atomic_model/linear_atomic_model.py (3 hunks)
- deepmd/jax/atomic_model/pairtab_atomic_model.py (3 hunks)
- deepmd/jax/descriptor/dpa1.py (4 hunks)
- deepmd/jax/descriptor/dpa2.py (3 hunks)
- deepmd/jax/descriptor/dpa3.py (3 hunks)
- deepmd/jax/descriptor/hybrid.py (3 hunks)
- deepmd/jax/descriptor/repflows.py (4 hunks)
- deepmd/jax/descriptor/repformers.py (4 hunks)
- deepmd/jax/descriptor/se_e2_a.py (3 hunks)
- deepmd/jax/descriptor/se_e2_r.py (3 hunks)
- deepmd/jax/descriptor/se_t.py (3 hunks)
- deepmd/jax/descriptor/se_t_tebd.py (3 hunks)
- deepmd/jax/env.py (2 hunks)
- deepmd/jax/fitting/fitting.py (4 hunks)
- deepmd/jax/utils/exclude_mask.py (3 hunks)
- deepmd/jax/utils/network.py (4 hunks)
- deepmd/jax/utils/type_embed.py (2 hunks)
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2024-10-26T02:09:01.365Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4258
File: deepmd/jax/utils/neighbor_stat.py:98-101
Timestamp: 2024-10-26T02:09:01.365Z
Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.
Applied to files:
- deepmd/jax/atomic_model/linear_atomic_model.py
- deepmd/jax/atomic_model/pairtab_atomic_model.py
- deepmd/jax/descriptor/se_t_tebd.py
- deepmd/jax/descriptor/dpa1.py
- deepmd/jax/descriptor/dpa3.py
- deepmd/jax/fitting/fitting.py
- deepmd/jax/descriptor/dpa2.py
- deepmd/jax/descriptor/repflows.py
- deepmd/jax/atomic_model/base_atomic_model.py
- deepmd/jax/descriptor/se_t.py
- deepmd/jax/descriptor/repformers.py
- deepmd/jax/utils/network.py
- deepmd/jax/descriptor/se_e2_r.py
- deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-10-30T20:08:12.531Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4284
File: deepmd/jax/__init__.py:8-8
Timestamp: 2024-10-30T20:08:12.531Z
Learning: In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.
Applied to files:
- deepmd/jax/atomic_model/linear_atomic_model.py
- deepmd/jax/atomic_model/pairtab_atomic_model.py
- deepmd/jax/descriptor/se_t_tebd.py
- deepmd/jax/descriptor/dpa1.py
- deepmd/jax/fitting/fitting.py
- deepmd/jax/env.py
- deepmd/jax/descriptor/dpa2.py
- deepmd/jax/descriptor/repflows.py
- deepmd/jax/atomic_model/base_atomic_model.py
- deepmd/jax/descriptor/se_t.py
- deepmd/jax/descriptor/repformers.py
- deepmd/jax/utils/network.py
- deepmd/jax/descriptor/se_e2_r.py
- deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-11-23T00:01:06.984Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4406
File: deepmd/dpmodel/array_api.py:51-53
Timestamp: 2024-11-23T00:01:06.984Z
Learning: In `deepmd/dpmodel/array_api.py`, the `__array_api_version__` attribute is guaranteed by the Array API standard to always be present, so error handling for its absence is not required.
Applied to files:
- deepmd/jax/descriptor/dpa1.py
- deepmd/jax/descriptor/dpa3.py
- deepmd/jax/atomic_model/base_atomic_model.py
- deepmd/jax/descriptor/se_t.py
- deepmd/jax/descriptor/hybrid.py
- deepmd/jax/descriptor/repformers.py
- deepmd/jax/utils/network.py
- deepmd/jax/descriptor/se_e2_r.py
- deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-10-30T03:16:31.013Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4278
File: source/tests/array_api_strict/fitting/fitting.py:58-66
Timestamp: 2024-10-30T03:16:31.013Z
Learning: The attributes `scale` and `constant_matrix` are specific to the `PolarFittingNet` class and should have their special handling within this class rather than in the general `setattr_for_general_fitting` function.
Applied to files:
deepmd/jax/fitting/fitting.py
🧬 Code graph analysis (7)
deepmd/jax/atomic_model/linear_atomic_model.py (1)
- deepmd/jax/atomic_model/pairtab_atomic_model.py (1)
  - PairTabAtomicModel(31-58)

deepmd/jax/utils/type_embed.py (2)
- deepmd/dpmodel/utils/type_embed.py (1)
  - TypeEmbedNet(30-221)
- deepmd/jax/common.py (5)
  - ArrayAPIVariable(86-97)
  - flax_module(45-83)
  - to_jax_array(20-20)
  - to_jax_array(24-24)
  - to_jax_array(27-42)

deepmd/jax/descriptor/dpa2.py (2)
- deepmd/jax/descriptor/se_t_tebd.py (1)
  - DescrptBlockSeTTebd(38-55)
- deepmd/jax/utils/type_embed.py (1)
  - TypeEmbedNet(26-36)

deepmd/jax/descriptor/repflows.py (1)
- deepmd/jax/common.py (1)
  - ArrayAPIVariable(86-97)

deepmd/jax/descriptor/hybrid.py (2)
- source/tests/consistent/descriptor/test_hybrid.py (1)
  - data(50-78)
- deepmd/dpmodel/descriptor/hybrid.py (2)
  - deserialize(379-389)
  - serialize(370-376)

deepmd/jax/descriptor/se_e2_r.py (1)
- deepmd/jax/utils/network.py (1)
  - NetworkCollection(78-88)

deepmd/jax/descriptor/se_e2_a.py (1)
- deepmd/jax/utils/network.py (1)
  - NetworkCollection(78-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
🔇 Additional comments (31)
deepmd/jax/utils/type_embed.py (2)
6-8: Imports look good. The addition of version checking imports is appropriate for the Flax 0.12 compatibility changes.
Also applies to: 16-19
32-33: `nnx.data(None)` is valid and correct for Flax 0.12+. According to Flax 0.12 documentation, `nnx.data(None)` is the correct way to set an attribute's default value to `None` while registering it as pytree data. The pattern is explicitly documented in the Flax 0.12 release notes regarding stricter pytree attribute rules. Wrapping `None` is necessary because calling `nnx.data()` without arguments uses a Missing sentinel by default, whereas `nnx.data(None)` explicitly sets the value to `None` and includes it in the pytree. This usage is consistent with the codebase's approach and safe to use.
deepmd/jax/descriptor/dpa3.py (1)
6-8: LGTM - Consistent Flax 0.12 compatibility pattern. The version-gated wrapping with `nnx.data()` follows the same pattern as other files in this PR. The implementation is consistent for handling `mean` and `stddev` attributes.
Also applies to: 22-25, 39-40
deepmd/jax/descriptor/se_t.py (1)
6-8: LGTM - Consistent Flax 0.12 compatibility pattern. The version-gated wrapping is applied consistently for the `dstd` and `davg` attributes, matching the pattern across the PR.
Also applies to: 19-22, 41-42
deepmd/jax/descriptor/se_e2_a.py (2)
6-8: LGTM - Imports are consistent with the PR pattern.
Also applies to: 19-22
40-41: LGTM - Flax 0.12 compatibility applied to multiple attributes. The version-gated wrapping is correctly applied to both `dstd`/`davg` and `embeddings` attributes, extending the pattern consistently.
Also applies to: 45-46
deepmd/jax/atomic_model/pairtab_atomic_model.py (2)
7-9: LGTM - Consistent Flax 0.12 compatibility for PairTab model. The version-gated wrapping is correctly applied to `tab_info` and `tab_data` attributes, following the established pattern.
Also applies to: 22-26, 38-39
54-54: Clarify the purpose of stop_gradient on nlist. Line 54 wraps `nlist` with `jax.lax.stop_gradient()`, which appears unrelated to the Flax 0.12 compatibility changes. This prevents gradient computation through the neighbor list.
Is this change:
- A separate bug fix?
- Required for Flax 0.12 compatibility?
- A performance optimization?
Please clarify the reason for this change, as it's not mentioned in the PR description and doesn't follow the pattern seen in other files.
deepmd/jax/descriptor/se_t_tebd.py (1)
6-8: LGTM - Consistent Flax 0.12 compatibility pattern. The version-gated wrapping is applied consistently to `mean` and `stddev` attributes in the `DescrptBlockSeTTebd` class.
Also applies to: 22-25, 44-45
deepmd/jax/atomic_model/base_atomic_model.py (1)
6-8: LGTM - Flax 0.12 compatibility applied to base helper function. The version-gated wrapping is correctly implemented in the `base_atomic_model_set_attr` helper function for `out_bias` and `out_std` attributes. Since this is a shared helper function (imported by `pairtab_atomic_model.py` at line 15), this ensures consistent behavior across atomic model implementations.
Also applies to: 14-17, 29-30
deepmd/jax/env.py (1)
8-8: Unable to complete full verification: repository cloning was blocked by sandbox limitations.
However, based on the web search results I gathered:
Verified facts:
- `flax.__version__` does return a PEP 440-compatible string ✓
- `nnx.data()` is confirmed as a field descriptor in Flax 0.12+ for marking pytree "data" attributes ✓

What I cannot verify without code access:
- How `flax_version` is actually used at line 23 in `deepmd/jax/env.py`
- Whether the version parsing logic correctly handles `packaging.version.Version(flax_version)`
- The specific usage pattern of the `nnx.data()` wrapper in this PR

Recommendation:
The review comment's verification requests are technically sound. Both Flax APIs referenced exist and work as expected per official documentation. However, without seeing the actual implementation context, I cannot confirm whether the code correctly applies these APIs.
Final rewritten review comment:
Verify version parsing and nnx.data() usage in Flax 0.12 compatibility logic.
The import of `flax.__version__` is correct. Flax maintains PEP 440-compatible version strings that work with `packaging.version.Version()`. However, please verify at line 23:
- That the version comparison logic handles `flax.__version__` correctly without exceptions
- That the `nnx.data()` field descriptor is used correctly for Flax 0.12+ compatibility (as a field marker at assignment, not a wrapper function)

Also applies to: 23-23
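The version check that this comment asks to verify could look roughly like the sketch below. It is an illustration only, not the code in `deepmd/jax/env.py`: the names `at_least` and `needs_flax_012_path` are hypothetical, and the fallback branch is an assumption added so the example also runs where `packaging` is not installed.

```python
import re

try:
    # Preferred path: PEP 440-aware comparison via the `packaging` library.
    from packaging.version import Version

    def at_least(current: str, minimum: str) -> bool:
        return Version(current) >= Version(minimum)

except ImportError:
    # Fallback (an assumption for environments without `packaging`):
    # compare only the numeric release segment, padded to three parts.
    def _release(version: str) -> tuple:
        match = re.match(r"\d+(?:\.\d+)*", version)
        parts = [int(p) for p in match.group(0).split(".")] if match else []
        return tuple((parts + [0, 0, 0])[:3])

    def at_least(current: str, minimum: str) -> bool:
        return _release(current) >= _release(minimum)


def needs_flax_012_path(flax_version: str) -> bool:
    """Hypothetical helper: True when the Flax >= 0.12.0 branch applies."""
    return at_least(flax_version, "0.12.0")


# Plain string comparison gets this wrong, because "9" sorts after "1".
assert ("0.9.0" >= "0.12.0") is True
assert needs_flax_012_path("0.9.0") is False
assert needs_flax_012_path("0.12.0") is True
```

Parsing the version instead of comparing raw strings is the point of the sketch: a lexicographic check would send Flax 0.9.x down the 0.12 code path.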
deepmd/jax/descriptor/se_e2_r.py (1)
40-41: Verify that wrapping None values with nnx.data(None) is necessary for Flax 0.12.0. The code wraps `None` values with `nnx.data(None)` when `flax_version >= 0.12.0`. This pattern appears multiple times across the codebase. Please confirm this is intentional and required by Flax 0.12.0's NNX module system, as wrapping `None` is unusual.
Also applies to: 45-46
deepmd/jax/utils/network.py (2)
63-66: Verify the wrapping pattern for layers attribute. `NativeNet` wraps `layers` with `nnx.List(value)` without wrapping individual items, while `NetworkCollection` (lines 85-88) wraps each item with `nnx.data(item)` before wrapping the list with `nnx.List()`. Please confirm whether the layers items should also be wrapped with `nnx.data()` for consistency, or if this difference is intentional based on the content type.
85-88: LGTM! The wrapping pattern correctly applies `nnx.data()` to each item before wrapping the list with `nnx.List()` for Flax 0.12.0 compatibility.
deepmd/jax/atomic_model/linear_atomic_model.py (3)
40-43: LGTM! The mapping_list wrapping correctly applies the pattern of wrapping each item with `nnx.data()` within `nnx.List()` for Flax 0.12.0 compatibility.
44-47: Verify that test behavior is unaffected by the zbl_weight no-op handling. The `zbl_weight` attribute is now discarded (early return) instead of being wrapped, to avoid Flax trace mutation errors. Since the comment indicates this attribute is only used in tests, please confirm that all relevant tests still pass with this change.
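A minimal sketch of what "discarded (early return)" means in practice, assuming the discard happens inside `__setattr__`. `LinearModelSketch` is a hypothetical stand-in written for this illustration, not the deepmd-kit class, and it uses no Flax code.

```python
from typing import Any


class LinearModelSketch:
    """Hypothetical stand-in; only illustrates the early-return pattern."""

    # Attribute names that are silently dropped on assignment (assumed set).
    _DISCARDED = frozenset({"zbl_weight"})

    def __setattr__(self, name: str, value: Any) -> None:
        if name in self._DISCARDED:
            return  # no-op: the value is never stored on the instance
        super().__setattr__(name, value)


model = LinearModelSketch()
model.zbl_weight = 0.5      # accepted without error, but not stored
model.models = ["a", "b"]   # stored normally

assert not hasattr(model, "zbl_weight")
assert model.models == ["a", "b"]
```

Under this pattern a later read of `model.zbl_weight` raises `AttributeError`, which is why any test that reads the attribute is worth re-running.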
48-54: LGTM! The models deserialization and wrapping correctly applies the Flax 0.12.0 compatibility pattern.
deepmd/jax/descriptor/repformers.py (2)
42-61: LGTM! The attribute wrapping for `mean`, `stddev`, and `layers` correctly implements the Flax 0.12.0 compatibility pattern, consistent with other descriptor modules.
98-133: LGTM! The attribute wrapping in `RepformerLayer` correctly implements the Flax 0.12.0 compatibility pattern for both None values and lists, consistent with the broader codebase changes.
deepmd/jax/fitting/fitting.py (2)
39-57: LGTM! The centralized attribute handling in `setattr_for_general_fitting` correctly implements the Flax 0.12.0 compatibility pattern and is appropriately reused across multiple fitting classes.
94-106: LGTM! The PolarFittingNet correctly applies the Flax 0.12.0 compatibility pattern to its specific attributes (`scale`, `constant_matrix`) in addition to the general fitting attributes.
deepmd/jax/descriptor/hybrid.py (1)
28-38: LGTM! The list wrapping for both `nlist_cut_idx` and `descrpt_list` correctly implements the Flax 0.12.0 compatibility pattern.
deepmd/jax/descriptor/dpa1.py (2)
63-71: LGTM! The `attention_layers` list wrapping correctly implements the Flax 0.12.0 compatibility pattern.
75-96: LGTM! The attribute wrapping in `DescrptBlockSeAtten` correctly implements the Flax 0.12.0 compatibility pattern for both mean/stddev and embeddings attributes.
deepmd/jax/descriptor/dpa2.py (1)
45-74: LGTM! The attribute wrapping in `DescrptDPA2.__setattr__` correctly implements the Flax 0.12.0 compatibility pattern. The handling of `g1_shape_tranform` and `tebd_transform` appropriately covers all cases: None values, NativeLayerDP, and IdentityDP instances.
deepmd/jax/utils/exclude_mask.py (2)
6-20: LGTM! Imports support version-conditional behavior. The imports of `Version`, `flax_version`, and `nnx` are appropriate for implementing version-conditional wrapping logic.
23-44: The code pattern is correct; no changes needed. The web search confirms that wrapping `None` with `nnx.data(None)` is the intended and correct pattern for flax 0.12.0. In flax 0.12.0, `nnx.data(value)` is the required way to mark attributes as pytree "data" (rather than static), even when the value is `None`. This prevents strict pytree attribute handling errors introduced in that version. The conditional logic is sound: when `value is not None`, use `ArrayAPIVariable`; otherwise, for flax >= 0.12.0, use `nnx.data(None)` to properly annotate optional/cached fields.
deepmd/jax/descriptor/repflows.py (4)
6-22: LGTM! Imports are consistent with version-conditional pattern. The imports support the version-conditional wrapping logic needed for flax 0.12.0 compatibility.
34-39: Consistent pattern for None value handling. The pattern matches the approach in `exclude_mask.py` for handling None values with flax 0.12.0.
71-74: Consistent None value handling. The pattern for handling None values is consistent with the approach used throughout this PR.
75-78: Based on the web search results, I can now verify the concern. The search confirmed that:
- `nnx.data()` is the correct pattern for wrapping `nnx.Variable` instances - The documentation explicitly shows `self.v = nnx.data(nnx.Param(jnp.ones(3)))` as a valid pattern to wrap Variables and store them as pytree data.
- Flax 0.12.0 enforces stricter pytree handling - Containers with JAX Arrays must be marked with `nnx.data()` or use `nnx.List`/`nnx.Dict` for mixed data/static contents.
- The code pattern is correct - Wrapping `ArrayAPIVariable` instances (which extend `nnx.Variable`) with `nnx.data()` inside `nnx.List` is the proper and intended usage pattern for Flax 0.12.0+.

The original review comment raised a concern that turned out to be incorrect - the code is following the correct pattern.
The code pattern is correct; no issues found.
The wrapping of `ArrayAPIVariable` instances with `nnx.data()` inside `nnx.List` aligns with Flax 0.12.0's design for handling Variables as pytree data. This is the intended pattern documented in the Flax API and required for strict pytree compliance in 0.12.0+.
Pull request overview
This PR adds compatibility with flax 0.12 by introducing version-specific handling for JAX modules. The changes wrap certain values with nnx.data() and nnx.List() when using flax version 0.12.0 or higher.
- Exports `flax_version` from `deepmd/jax/env.py` for version checking
- Adds conditional logic across multiple descriptor, fitting, and atomic model files to handle flax 0.12+ compatibility
- Special handling for the `zbl_weight` attribute in linear atomic model
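The version-conditional handling summarized in the bullets above can be sketched without Flax installed. Everything in this block is a stand-in written for illustration: `Variable`, `Data`, and `DataList` only mimic the roles the reviews attribute to `ArrayAPIVariable`, `nnx.data()`, and `nnx.List()`, and `ModuleSketch` with its `flax_ge_012` flag is hypothetical rather than code from the PR.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Variable:
    """Stand-in for an array-holding variable such as ArrayAPIVariable."""
    value: Any


@dataclass
class Data:
    """Stand-in for the result of nnx.data(); NOT the Flax API."""
    value: Any


class DataList(list):
    """Stand-in for nnx.List; NOT the Flax API."""


class ModuleSketch:
    """Hypothetical module that rewrites selected attributes on assignment."""

    # Assumption: the real code derives this from a flax_version comparison.
    flax_ge_012 = True

    def __setattr__(self, name: str, value: Any) -> None:
        if name in {"mean", "stddev"}:
            if value is not None:
                value = Variable(value)   # arrays become variables
            elif self.flax_ge_012:
                value = Data(None)        # None is marked as data on 0.12+
        elif name == "layers" and self.flax_ge_012:
            value = DataList(value)       # plain lists become data lists
        super().__setattr__(name, value)


class LegacyModuleSketch(ModuleSketch):
    """Same logic with the gate switched off, as for older Flax releases."""
    flax_ge_012 = False


new, old = ModuleSketch(), LegacyModuleSketch()
new.mean = None
old.mean = None
new.layers = [1, 2]
old.layers = [1, 2]
```

With the gate on, `new.mean` holds a `Data` marker and `new.layers` a `DataList`; with it off, `old.mean` stays `None` and `old.layers` stays a plain `list`. Keeping the branch inside `__setattr__` means callers assign attributes the same way on either Flax version.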
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 27 comments.
Show a summary per file
| File | Description |
|---|---|
| deepmd/jax/env.py | Exports flax_version for version checking across modules |
| deepmd/jax/utils/type_embed.py | Adds flax 0.12 compatibility for type embedding attributes |
| deepmd/jax/utils/network.py | Wraps layers and network collections with nnx.List for flax 0.12+ |
| deepmd/jax/utils/exclude_mask.py | Adds flax 0.12 compatibility for mask attributes |
| deepmd/jax/fitting/fitting.py | Handles fitting network attributes for flax 0.12+ |
| deepmd/jax/descriptor/se_t_tebd.py | Adds compatibility handling for SE-T-TEBD descriptor attributes |
| deepmd/jax/descriptor/se_t.py | Adds compatibility handling for SE-T descriptor attributes |
| deepmd/jax/descriptor/se_e2_r.py | Adds compatibility handling for SE-R descriptor attributes |
| deepmd/jax/descriptor/se_e2_a.py | Adds compatibility handling for SE-A descriptor attributes |
| deepmd/jax/descriptor/repformers.py | Extensive compatibility handling for Repformer layers and attributes |
| deepmd/jax/descriptor/repflows.py | Adds compatibility handling for Repflow layers and attributes |
| deepmd/jax/descriptor/hybrid.py | Wraps descriptor lists with nnx.List for flax 0.12+ |
| deepmd/jax/descriptor/dpa3.py | Adds compatibility handling for DPA3 descriptor attributes |
| deepmd/jax/descriptor/dpa2.py | Adds compatibility handling for DPA2 descriptor attributes |
| deepmd/jax/descriptor/dpa1.py | Adds compatibility handling for DPA1 attention layers and embeddings |
| deepmd/jax/atomic_model/pairtab_atomic_model.py | Adds compatibility handling for pair table atomic model attributes |
| deepmd/jax/atomic_model/linear_atomic_model.py | Special handling for zbl_weight and wraps model lists with nnx.List |
| deepmd/jax/atomic_model/base_atomic_model.py | Adds compatibility handling for base atomic model attributes |
Because the latest TF and JAX releases are not compatible with each other, I am keeping the old JAX and flax versions in the CI.