@njzjz njzjz commented Nov 27, 2025

Because the latest TF and JAX releases are not compatible with each other, I am keeping the old JAX and Flax versions in CI.

Summary by CodeRabbit

Release Notes

  • Chores
    • Updated internal handling across model and descriptor modules to support Flax 0.12.0 and later versions with conditional runtime behavior based on detected Flax version.

codecov bot commented Nov 29, 2025

Codecov Report

❌ Patch coverage is 52.17391% with 55 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.28%. Comparing base (1ee33c8) to head (10b378e).
⚠️ Report is 3 commits behind head on devel.

Files with missing lines                        Patch %   Lines
deepmd/jax/descriptor/repformers.py             50.00%    10 Missing ⚠️
deepmd/jax/descriptor/repflows.py               50.00%    5 Missing ⚠️
deepmd/jax/descriptor/dpa1.py                   50.00%    4 Missing ⚠️
deepmd/jax/descriptor/dpa2.py                   50.00%    4 Missing ⚠️
deepmd/jax/descriptor/se_e2_a.py                33.33%    4 Missing ⚠️
deepmd/jax/descriptor/se_e2_r.py                33.33%    4 Missing ⚠️
deepmd/jax/utils/exclude_mask.py                33.33%    4 Missing ⚠️
deepmd/jax/fitting/fitting.py                   50.00%    3 Missing ⚠️
deepmd/jax/atomic_model/base_atomic_model.py    50.00%    2 Missing ⚠️
deepmd/jax/atomic_model/linear_atomic_model.py  66.66%    2 Missing ⚠️
... and 7 more
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #5067      +/-   ##
==========================================
- Coverage   84.33%   84.28%   -0.06%     
==========================================
  Files         709      709              
  Lines       70435    70547     +112     
  Branches     3618     3619       +1     
==========================================
+ Hits        59402    59458      +56     
- Misses       9867     9922      +55     
- Partials     1166     1167       +1     
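
The project percentages in this report can be reproduced from the raw counts in the diff table. The short Python check below is only a sanity sketch: the patch size of 115 lines is inferred from the reported 55 missing lines and the 52.17391% figure, not stated by Codecov, and the way Codecov rounds values or treats partial lines is not modeled here.

```python
def coverage_percent(hits: int, lines: int) -> float:
    """Line coverage as a percentage of tracked lines."""
    return 100.0 * hits / lines


# Raw counts copied from the Coverage Diff table.
base = coverage_percent(hits=59402, lines=70435)
head = coverage_percent(hits=59458, lines=70547)

# Inferred, not reported: 55 missing lines at 52.17391% patch coverage
# is consistent with a 115-line patch of which 60 lines are covered.
patch = coverage_percent(hits=115 - 55, lines=115)

print(f"base={base:.4f}% head={head:.4f}% patch={patch:.5f}%")
```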

@njzjz njzjz marked this pull request as ready for review November 29, 2025 08:04
Copilot finished reviewing on behalf of njzjz November 29, 2025 08:07
coderabbitai bot commented Nov 29, 2025

Walkthrough

This PR adds version-conditional wrapping of JAX attributes with Flax NNX data structures across descriptor, atomic model, fitting, and utility modules. When Flax version >= 0.12.0, specific attributes are transformed using nnx.data() and nnx.List() wrappers. The flax_version identifier is exported from the environment module for centralized version checking.
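
As a rough illustration of the pattern the walkthrough describes, the sketch below gates attribute wrapping on a version check inside `__setattr__`. It is not the PR's code: `mark_data` and `ListBox` are hypothetical stand-ins for `nnx.data()` and `nnx.List()` so that the example runs without Flax, and `FLAX_VERSION`, `DescriptorSketch`, and the attribute names are assumptions chosen for illustration.

```python
from typing import Any

# Hypothetical: in the PR this information comes from flax.__version__.
FLAX_VERSION = (0, 12, 0)
NEEDS_WRAPPING = FLAX_VERSION >= (0, 12, 0)


def mark_data(value: Any) -> tuple[str, Any]:
    """Stand-in for nnx.data(): tags a value so the effect is visible."""
    return ("data", value)


class ListBox(list):
    """Stand-in for nnx.List(): a list subclass used as a marker type."""


class DescriptorSketch:
    def __setattr__(self, name: str, value: Any) -> None:
        if NEEDS_WRAPPING:
            if name in {"mean", "stddev"}:
                value = mark_data(value)
            elif name == "layers":
                value = ListBox(value)
        # On older versions the value falls through and is stored unchanged.
        super().__setattr__(name, value)


desc = DescriptorSketch()
desc.mean = 0.5
desc.layers = ["layer0", "layer1"]
desc.ntypes = 2  # attributes outside the gated set are never wrapped
```

Keeping the check in one place inside `__setattr__` means call sites assign attributes the same way on every Flax version.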

Changes

Cohort, File(s), and Change Summary

  • JAX Environment Export
    • Files: deepmd/jax/env.py
    • Change: Exports flax_version from flax.__version__ as a public symbol for version-aware conditional logic across modules.
  • JAX Descriptors: Version-gated NNX Wrapping
    • Files: deepmd/jax/descriptor/dpa1.py, deepmd/jax/descriptor/dpa2.py, deepmd/jax/descriptor/dpa3.py, deepmd/jax/descriptor/hybrid.py, deepmd/jax/descriptor/repflows.py, deepmd/jax/descriptor/repformers.py, deepmd/jax/descriptor/se_e2_a.py, deepmd/jax/descriptor/se_e2_r.py, deepmd/jax/descriptor/se_t.py, deepmd/jax/descriptor/se_t_tebd.py
    • Change: Adds Version import and flax version/nnx imports. Conditionally wraps attributes (mean/stddev, layers, embeddings, etc.) with nnx.data() or nnx.List() when Flax version >= 0.12.0 in __setattr__ paths.
  • JAX Atomic Models: Version-gated NNX Wrapping
    • Files: deepmd/jax/atomic_model/base_atomic_model.py, deepmd/jax/atomic_model/linear_atomic_model.py, deepmd/jax/atomic_model/pairtab_atomic_model.py
    • Change: Adds Version import and flax version/nnx imports. Conditionally wraps output attributes and model lists with nnx.data() or nnx.List() when Flax version >= 0.12.0. Updates forward pass in pairtab to wrap nlist with jax.lax.stop_gradient().
  • JAX Utilities: NNX Wrapping & New Classes
    • Files: deepmd/jax/utils/network.py, deepmd/jax/utils/type_embed.py
    • Change: Adds __setattr__ overrides in NativeNet and NetworkCollection to conditionally wrap layers and networks with nnx.List() for Flax >= 0.12.0. TypeEmbedNet adds version-gated nnx.data() wrapping for attributes.
  • JAX Utilities: Exclude Masks (New Classes)
    • Files: deepmd/jax/utils/exclude_mask.py
    • Change: Introduces new AtomExcludeMask and PairExcludeMask classes extending DP counterparts. Adds version-gated nnx.data() wrapping for the type_mask attribute when Flax version >= 0.12.0.
  • JAX Fitting: Version-gated NNX Wrapping
    • Files: deepmd/jax/fitting/fitting.py
    • Change: Adds Version import and flax version/nnx imports. Conditionally wraps fitted attributes (scale, constant_matrix, etc.) with nnx.data() when Flax version >= 0.12.0 in setattr_for_general_fitting.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25–30 minutes

Areas requiring extra attention:

  • Consistency of the version check across all 18 modified files: verify that the Version(flax_version) >= Version("0.12.0") logic is applied uniformly
  • The new AtomExcludeMask and PairExcludeMask classes in exclude_mask.py: confirm inheritance and override correctness
  • The __setattr__ additions in network.py (NativeNet, NetworkCollection) and their integration with parent class behavior
  • The jax.lax.stop_gradient(nlist) call in the pairtab_atomic_model.py forward path: verify that its intent is clear
  • Public export of flax_version in env.py and its usage across the codebase
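
The first item matters because version strings do not compare correctly as plain strings. The PR is described as using `Version(flax_version) >= Version("0.12.0")`; the standard-library-only sketch below uses hypothetical helpers, `release_tuple` and `at_least`, merely to show why a parsed comparison is needed. It ignores pre-release and local-version suffixes, which `packaging.version.Version` handles properly.

```python
import re


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '0.12.0rc1' -> (0, 12, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric release segment in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def at_least(version: str, minimum: str) -> bool:
    """True if `version` is at or above `minimum`, compared numerically."""
    have, need = release_tuple(version), release_tuple(minimum)
    width = max(len(have), len(need))
    # Pad with zeros so that '0.12' and '0.12.0' compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


# Lexicographic string comparison gets this wrong because '1' < '9'.
string_says_older = "0.12.0" < "0.9.0"
parsed_says_newer = at_least("0.12.0", "0.9.0")
```

A single shared helper of this kind, defined next to the exported version string, is one way to keep the threshold from drifting between files.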

Possibly related PRs

  • feat(jax/array-api): se_e3 #4286: Modifies JAX descriptor DescrptSeT (deepmd/jax/descriptor/se_t.py) with similar attribute-handling changes for consistency.
  • feat(jax): zbl #4301: Updates JAX atomic model files (linear_atomic_model.py, pairtab_atomic_model.py) with overlapping __setattr__ and forward-pass modifications.
  • feat(jax/array-api): dpa1 #4160: Modifies deepmd/jax/utils/network.py to add custom parameter/list wrapping behavior for network attributes.

Suggested reviewers

  • wanghan-iapcm
  • iProzd
  • anyangml

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Check name: Docstring Coverage
    • Status: ⚠️ Warning
    • Explanation: Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%.
    • Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Check name: Description Check
    • Status: ✅ Passed
    • Explanation: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Check name: Title check
    • Status: ✅ Passed
    • Explanation: The title 'fix(jax): fix compatibility with flax 0.12' accurately describes the main purpose of the PR - addressing compatibility issues with Flax 0.12 by adding version-gated conditional wrapping of values throughout the JAX codebase.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ee33c8 and 10b378e.

📒 Files selected for processing (18)
  • deepmd/jax/atomic_model/base_atomic_model.py (2 hunks)
  • deepmd/jax/atomic_model/linear_atomic_model.py (3 hunks)
  • deepmd/jax/atomic_model/pairtab_atomic_model.py (3 hunks)
  • deepmd/jax/descriptor/dpa1.py (4 hunks)
  • deepmd/jax/descriptor/dpa2.py (3 hunks)
  • deepmd/jax/descriptor/dpa3.py (3 hunks)
  • deepmd/jax/descriptor/hybrid.py (3 hunks)
  • deepmd/jax/descriptor/repflows.py (4 hunks)
  • deepmd/jax/descriptor/repformers.py (4 hunks)
  • deepmd/jax/descriptor/se_e2_a.py (3 hunks)
  • deepmd/jax/descriptor/se_e2_r.py (3 hunks)
  • deepmd/jax/descriptor/se_t.py (3 hunks)
  • deepmd/jax/descriptor/se_t_tebd.py (3 hunks)
  • deepmd/jax/env.py (2 hunks)
  • deepmd/jax/fitting/fitting.py (4 hunks)
  • deepmd/jax/utils/exclude_mask.py (3 hunks)
  • deepmd/jax/utils/network.py (4 hunks)
  • deepmd/jax/utils/type_embed.py (2 hunks)
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2024-10-26T02:09:01.365Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4258
File: deepmd/jax/utils/neighbor_stat.py:98-101
Timestamp: 2024-10-26T02:09:01.365Z
Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.

Applied to files:

  • deepmd/jax/atomic_model/linear_atomic_model.py
  • deepmd/jax/atomic_model/pairtab_atomic_model.py
  • deepmd/jax/descriptor/se_t_tebd.py
  • deepmd/jax/descriptor/dpa1.py
  • deepmd/jax/descriptor/dpa3.py
  • deepmd/jax/fitting/fitting.py
  • deepmd/jax/descriptor/dpa2.py
  • deepmd/jax/descriptor/repflows.py
  • deepmd/jax/atomic_model/base_atomic_model.py
  • deepmd/jax/descriptor/se_t.py
  • deepmd/jax/descriptor/repformers.py
  • deepmd/jax/utils/network.py
  • deepmd/jax/descriptor/se_e2_r.py
  • deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-10-30T20:08:12.531Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4284
File: deepmd/jax/__init__.py:8-8
Timestamp: 2024-10-30T20:08:12.531Z
Learning: In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.

Applied to files:

  • deepmd/jax/atomic_model/linear_atomic_model.py
  • deepmd/jax/atomic_model/pairtab_atomic_model.py
  • deepmd/jax/descriptor/se_t_tebd.py
  • deepmd/jax/descriptor/dpa1.py
  • deepmd/jax/fitting/fitting.py
  • deepmd/jax/env.py
  • deepmd/jax/descriptor/dpa2.py
  • deepmd/jax/descriptor/repflows.py
  • deepmd/jax/atomic_model/base_atomic_model.py
  • deepmd/jax/descriptor/se_t.py
  • deepmd/jax/descriptor/repformers.py
  • deepmd/jax/utils/network.py
  • deepmd/jax/descriptor/se_e2_r.py
  • deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-11-23T00:01:06.984Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4406
File: deepmd/dpmodel/array_api.py:51-53
Timestamp: 2024-11-23T00:01:06.984Z
Learning: In `deepmd/dpmodel/array_api.py`, the `__array_api_version__` attribute is guaranteed by the Array API standard to always be present, so error handling for its absence is not required.

Applied to files:

  • deepmd/jax/descriptor/dpa1.py
  • deepmd/jax/descriptor/dpa3.py
  • deepmd/jax/atomic_model/base_atomic_model.py
  • deepmd/jax/descriptor/se_t.py
  • deepmd/jax/descriptor/hybrid.py
  • deepmd/jax/descriptor/repformers.py
  • deepmd/jax/utils/network.py
  • deepmd/jax/descriptor/se_e2_r.py
  • deepmd/jax/descriptor/se_e2_a.py
📚 Learning: 2024-10-30T03:16:31.013Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4278
File: source/tests/array_api_strict/fitting/fitting.py:58-66
Timestamp: 2024-10-30T03:16:31.013Z
Learning: The attributes `scale` and `constant_matrix` are specific to the `PolarFittingNet` class and should have their special handling within this class rather than in the general `setattr_for_general_fitting` function.

Applied to files:

  • deepmd/jax/fitting/fitting.py
🧬 Code graph analysis (7)
deepmd/jax/atomic_model/linear_atomic_model.py (1)
deepmd/jax/atomic_model/pairtab_atomic_model.py (1)
  • PairTabAtomicModel (31-58)
deepmd/jax/utils/type_embed.py (2)
deepmd/dpmodel/utils/type_embed.py (1)
  • TypeEmbedNet (30-221)
deepmd/jax/common.py (5)
  • ArrayAPIVariable (86-97)
  • flax_module (45-83)
  • to_jax_array (20-20)
  • to_jax_array (24-24)
  • to_jax_array (27-42)
deepmd/jax/descriptor/dpa2.py (2)
deepmd/jax/descriptor/se_t_tebd.py (1)
  • DescrptBlockSeTTebd (38-55)
deepmd/jax/utils/type_embed.py (1)
  • TypeEmbedNet (26-36)
deepmd/jax/descriptor/repflows.py (1)
deepmd/jax/common.py (1)
  • ArrayAPIVariable (86-97)
deepmd/jax/descriptor/hybrid.py (2)
source/tests/consistent/descriptor/test_hybrid.py (1)
  • data (50-78)
deepmd/dpmodel/descriptor/hybrid.py (2)
  • deserialize (379-389)
  • serialize (370-376)
deepmd/jax/descriptor/se_e2_r.py (1)
deepmd/jax/utils/network.py (1)
  • NetworkCollection (78-88)
deepmd/jax/descriptor/se_e2_a.py (1)
deepmd/jax/utils/network.py (1)
  • NetworkCollection (78-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Agent
🔇 Additional comments (31)
deepmd/jax/utils/type_embed.py (2)

6-8: Imports look good.

The addition of version checking imports is appropriate for the Flax 0.12 compatibility changes.

Also applies to: 16-19


32-33: nnx.data(None) is valid and correct for Flax 0.12+

According to Flax 0.12 documentation, nnx.data(None) is the correct way to set an attribute's default value to None while registering it as pytree data. The pattern is explicitly documented in the Flax 0.12 release notes regarding stricter pytree attribute rules. Wrapping None is necessary because calling nnx.data() without arguments uses a Missing sentinel by default, whereas nnx.data(None) explicitly sets the value to None and includes it in the pytree. This usage is consistent with the codebase's approach and safe to use.

deepmd/jax/descriptor/dpa3.py (1)

6-8: LGTM - Consistent Flax 0.12 compatibility pattern.

The version-gated wrapping with nnx.data() follows the same pattern as other files in this PR. The implementation is consistent for handling mean and stddev attributes.

Also applies to: 22-25, 39-40

deepmd/jax/descriptor/se_t.py (1)

6-8: LGTM - Consistent Flax 0.12 compatibility pattern.

The version-gated wrapping is applied consistently for the dstd and davg attributes, matching the pattern across the PR.

Also applies to: 19-22, 41-42

deepmd/jax/descriptor/se_e2_a.py (2)

6-8: LGTM - Imports are consistent with the PR pattern.

Also applies to: 19-22


40-41: LGTM - Flax 0.12 compatibility applied to multiple attributes.

The version-gated wrapping is correctly applied to both dstd/davg and embeddings attributes, extending the pattern consistently.

Also applies to: 45-46

deepmd/jax/atomic_model/pairtab_atomic_model.py (2)

7-9: LGTM - Consistent Flax 0.12 compatibility for PairTab model.

The version-gated wrapping is correctly applied to tab_info and tab_data attributes, following the established pattern.

Also applies to: 22-26, 38-39


54-54: Clarify the purpose of stop_gradient on nlist.

Line 54 wraps nlist with jax.lax.stop_gradient(), which appears unrelated to the Flax 0.12 compatibility changes. This prevents gradient computation through the neighbor list.

Is this change:

  1. A separate bug fix?
  2. Required for Flax 0.12 compatibility?
  3. A performance optimization?

Please clarify the reason for this change, as it's not mentioned in the PR description and doesn't follow the pattern seen in other files.
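
For context on what `jax.lax.stop_gradient` does to an index array, the JAX sketch below compares gradients with and without it. The `energy` function, its arguments, and the toy data are hypothetical and unrelated to the actual PairTab implementation; the sketch also assumes that `nlist` holds integer indices.

```python
import jax
import jax.numpy as jnp


def energy(coords, nlist, stop: bool):
    """Toy pair energy: squared distance between each atom and its one neighbor."""
    if stop:
        nlist = jax.lax.stop_gradient(nlist)
    neighbors = coords[nlist[:, 0]]  # gather neighbor coordinates by index
    return jnp.sum((coords - neighbors) ** 2)


coords = jnp.array([0.0, 1.0, 3.0])
nlist = jnp.array([[1], [2], [0]])  # integer neighbor indices, shape (natoms, 1)

# Differentiate with respect to coords only; nlist is closed over.
grad_plain = jax.grad(lambda c: energy(c, nlist, stop=False))(coords)
grad_stopped = jax.grad(lambda c: energy(c, nlist, stop=True))(coords)
```

In this toy case the two gradients match, since integer indices carry no gradient of their own. Whether the same reasoning explains the change in the PR depends on details not shown in this thread.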

deepmd/jax/descriptor/se_t_tebd.py (1)

6-8: LGTM - Consistent Flax 0.12 compatibility pattern.

The version-gated wrapping is applied consistently to mean and stddev attributes in the DescrptBlockSeTTebd class.

Also applies to: 22-25, 44-45

deepmd/jax/atomic_model/base_atomic_model.py (1)

6-8: LGTM - Flax 0.12 compatibility applied to base helper function.

The version-gated wrapping is correctly implemented in the base_atomic_model_set_attr helper function for out_bias and out_std attributes. Since this is a shared helper function (imported by pairtab_atomic_model.py at line 15), this ensures consistent behavior across atomic model implementations.

Also applies to: 14-17, 29-30

deepmd/jax/env.py (1)

8-8: Unable to complete full verification — repository cloning blocked by sandbox limitations.

However, based on the web search results I gathered:

Verified facts:

  • flax.__version__ does return a PEP 440-compatible string
  • nnx.data() is confirmed as a field descriptor in Flax 0.12+ for marking pytree "data" attributes ✓

What I cannot verify without code access:

  • How flax_version is actually used at line 23 in deepmd/jax/env.py
  • Whether the version parsing logic correctly handles packaging.version.Version(flax_version)
  • The specific usage pattern of nnx.data() wrapper in this PR

Recommendation:
The review comment's verification requests are technically sound. Both Flax APIs referenced exist and work as expected per official documentation. However, without seeing the actual implementation context, I cannot confirm whether the code correctly applies these APIs.


Final rewritten review comment:

Verify version parsing and nnx.data() usage in Flax 0.12 compatibility logic.

The import of flax.__version__ is correct. Flax maintains PEP 440-compatible version strings that work with packaging.version.Version(). However, please verify at line 23:

  1. That version comparison logic handles flax.__version__ correctly without exceptions
  2. That nnx.data() field descriptor is used correctly for Flax 0.12+ compatibility (as a field marker at assignment, not a wrapper function)

Also applies to: 23-23

deepmd/jax/descriptor/se_e2_r.py (1)

40-41: Verify that wrapping None values with nnx.data(None) is necessary for Flax 0.12.0.

The code wraps None values with nnx.data(None) when flax_version >= 0.12.0. This pattern appears multiple times across the codebase. Please confirm this is intentional and required by Flax 0.12.0's NNX module system, as wrapping None is unusual.

Also applies to: 45-46

deepmd/jax/utils/network.py (2)

63-66: Verify the wrapping pattern for layers attribute.

NativeNet wraps layers with nnx.List(value) without wrapping individual items, while NetworkCollection (lines 85-88) wraps each item with nnx.data(item) before wrapping the list. Please confirm whether the layers items should also be wrapped with nnx.data() for consistency, or if this difference is intentional based on the content type.


85-88: LGTM!

The wrapping pattern correctly applies nnx.data() to each item before wrapping the list with nnx.List() for Flax 0.12.0 compatibility.

deepmd/jax/atomic_model/linear_atomic_model.py (3)

40-43: LGTM!

The mapping_list wrapping correctly applies the pattern of wrapping each item with nnx.data() within nnx.List() for Flax 0.12.0 compatibility.


44-47: Verify that test behavior is unaffected by the zbl_weight no-op handling.

The zbl_weight attribute is now discarded (early return) instead of being wrapped, to avoid Flax trace mutation errors. Since the comment indicates this attribute is only used in tests, please confirm that all relevant tests still pass with this change.
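
A minimal sketch of what "discarded (early return)" means in a `__setattr__` override, and of its main side effect: the attribute is never stored, so a later read fails unless the class defines it elsewhere. `LinearModelSketch` and `other_attr` are hypothetical names, not the PR's implementation.

```python
from typing import Any


class LinearModelSketch:
    def __setattr__(self, name: str, value: Any) -> None:
        if name == "zbl_weight":
            # Drop the assignment entirely instead of storing or wrapping it.
            return
        super().__setattr__(name, value)


model = LinearModelSketch()
model.zbl_weight = [0.25, 0.75]  # silently ignored
model.other_attr = 0.5           # stored as usual

stored = hasattr(model, "other_attr")
dropped = not hasattr(model, "zbl_weight")
```

Under this pattern, any test that reads the attribute back from the instance would need another way to obtain the value, which is why confirming the test suite is a reasonable request.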


48-54: LGTM!

The models deserialization and wrapping correctly applies the Flax 0.12.0 compatibility pattern.

deepmd/jax/descriptor/repformers.py (2)

42-61: LGTM!

The attribute wrapping for mean, stddev, and layers correctly implements the Flax 0.12.0 compatibility pattern, consistent with other descriptor modules.


98-133: LGTM!

The attribute wrapping in RepformerLayer correctly implements the Flax 0.12.0 compatibility pattern for both None values and lists, consistent with the broader codebase changes.

deepmd/jax/fitting/fitting.py (2)

39-57: LGTM!

The centralized attribute handling in setattr_for_general_fitting correctly implements the Flax 0.12.0 compatibility pattern and is appropriately reused across multiple fitting classes.


94-106: LGTM!

The PolarFittingNet correctly applies the Flax 0.12.0 compatibility pattern to its specific attributes (scale, constant_matrix) in addition to the general fitting attributes.

deepmd/jax/descriptor/hybrid.py (1)

28-38: LGTM!

The list wrapping for both nlist_cut_idx and descrpt_list correctly implements the Flax 0.12.0 compatibility pattern.

deepmd/jax/descriptor/dpa1.py (2)

63-71: LGTM!

The attention_layers list wrapping correctly implements the Flax 0.12.0 compatibility pattern.


75-96: LGTM!

The attribute wrapping in DescrptBlockSeAtten correctly implements the Flax 0.12.0 compatibility pattern for both mean/stddev and embeddings attributes.

deepmd/jax/descriptor/dpa2.py (1)

45-74: LGTM!

The attribute wrapping in DescrptDPA2.__setattr__ correctly implements the Flax 0.12.0 compatibility pattern. The handling of g1_shape_tranform and tebd_transform appropriately covers all cases: None values, NativeLayerDP, and IdentityDP instances.

deepmd/jax/utils/exclude_mask.py (2)

6-20: LGTM! Imports support version-conditional behavior.

The imports of Version, flax_version, and nnx are appropriate for implementing version-conditional wrapping logic.


23-44: The code pattern is correct; no changes needed.

The web search confirms that wrapping None with nnx.data(None) is the intended and correct pattern for flax 0.12.0. In flax 0.12.0, nnx.data(value) is the required way to mark attributes as pytree "data" (rather than static), even when the value is None. This prevents strict pytree attribute handling errors introduced in that version. The conditional logic is sound: when value is not None, use ArrayAPIVariable; otherwise, for flax >= 0.12.0, use nnx.data(None) to properly annotate optional/cached fields.
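
The branch described in this comment can be sketched as follows. `ArrayBox` and `mark_data` are hypothetical stand-ins for `ArrayAPIVariable` and `nnx.data()`, used so that the example runs without Flax; `wrap_type_mask` and the `flax_at_least_0_12` flag are likewise illustrative and do not appear in the PR.

```python
from typing import Any, Optional


class ArrayBox:
    """Stand-in for ArrayAPIVariable: holds an array-like value."""

    def __init__(self, value: Any) -> None:
        self.value = value


def mark_data(value: Any) -> tuple[str, Any]:
    """Stand-in for nnx.data(): tags the value as pytree data."""
    return ("data", value)


def wrap_type_mask(value: Optional[list], flax_at_least_0_12: bool) -> Any:
    if value is not None:
        return ArrayBox(value)  # real values become variables on any version
    if flax_at_least_0_12:
        return mark_data(None)  # explicit data annotation for the None case
    return None                 # older Flax: plain None is stored as before
```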

deepmd/jax/descriptor/repflows.py (4)

6-22: LGTM! Imports are consistent with version-conditional pattern.

The imports support the version-conditional wrapping logic needed for flax 0.12.0 compatibility.


34-39: Consistent pattern for None value handling.

The pattern matches the approach in exclude_mask.py for handling None values with flax 0.12.0.


71-74: Consistent None value handling.

The pattern for handling None values is consistent with the approach used throughout this PR.


75-78: Based on the web search results, I can now verify the concern. The search confirmed that:

  1. nnx.data() is the correct pattern for wrapping nnx.Variable instances - The documentation explicitly shows self.v = nnx.data(nnx.Param(jnp.ones(3))) as a valid pattern to wrap Variables and store them as pytree data.

  2. Flax 0.12.0 enforces stricter pytree handling - Containers with JAX Arrays must be marked with nnx.data() or use nnx.List/nnx.Dict for mixed data/static contents.

  3. The code pattern is correct - Wrapping ArrayAPIVariable instances (which extend nnx.Variable) with nnx.data() inside nnx.List is the proper and intended usage pattern for Flax 0.12.0+.

The original review comment raised a concern that turned out to be incorrect - the code is following the correct pattern.


The code pattern is correct; no issues found.

The wrapping of ArrayAPIVariable instances with nnx.data() inside nnx.List aligns with Flax 0.12.0's design for handling Variables as pytree data. This is the intended pattern documented in the Flax API and required for strict pytree compliance in 0.12.0+.

Copilot AI left a comment

Pull request overview

This PR adds compatibility with flax 0.12 by introducing version-specific handling for JAX modules. The changes wrap certain values with nnx.data() and nnx.List() when using flax version 0.12.0 or higher.

  • Exports flax_version from deepmd/jax/env.py for version checking
  • Adds conditional logic across multiple descriptor, fitting, and atomic model files to handle flax 0.12+ compatibility
  • Special handling for the zbl_weight attribute in linear atomic model

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 27 comments.

File Description
deepmd/jax/env.py Exports flax_version for version checking across modules
deepmd/jax/utils/type_embed.py Adds flax 0.12 compatibility for type embedding attributes
deepmd/jax/utils/network.py Wraps layers and network collections with nnx.List for flax 0.12+
deepmd/jax/utils/exclude_mask.py Adds flax 0.12 compatibility for mask attributes
deepmd/jax/fitting/fitting.py Handles fitting network attributes for flax 0.12+
deepmd/jax/descriptor/se_t_tebd.py Adds compatibility handling for SE-T-TEBD descriptor attributes
deepmd/jax/descriptor/se_t.py Adds compatibility handling for SE-T descriptor attributes
deepmd/jax/descriptor/se_e2_r.py Adds compatibility handling for SE-R descriptor attributes
deepmd/jax/descriptor/se_e2_a.py Adds compatibility handling for SE-A descriptor attributes
deepmd/jax/descriptor/repformers.py Extensive compatibility handling for Repformer layers and attributes
deepmd/jax/descriptor/repflows.py Adds compatibility handling for Repflow layers and attributes
deepmd/jax/descriptor/hybrid.py Wraps descriptor lists with nnx.List for flax 0.12+
deepmd/jax/descriptor/dpa3.py Adds compatibility handling for DPA3 descriptor attributes
deepmd/jax/descriptor/dpa2.py Adds compatibility handling for DPA2 descriptor attributes
deepmd/jax/descriptor/dpa1.py Adds compatibility handling for DPA1 attention layers and embeddings
deepmd/jax/atomic_model/pairtab_atomic_model.py Adds compatibility handling for pair table atomic model attributes
deepmd/jax/atomic_model/linear_atomic_model.py Special handling for zbl_weight and wraps model lists with nnx.List
deepmd/jax/atomic_model/base_atomic_model.py Adds compatibility handling for base atomic model attributes

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Dec 1, 2025
Merged via the queue into deepmodeling:devel with commit a72b3af Dec 1, 2025
66 checks passed