Conversation


@iProzd commented Jul 28, 2025

Summary by CodeRabbit

  • New Features

    • Added an option to control whether output statistics are computed or loaded across atomic models.
  • Bug Fixes

    • More robust parameter transfer during fine‑tuning to handle renamed branches and missing pretrained keys.
  • Refactor

    • Revised output-statistics workflow and refined per‑type output bias application in composite models.
  • Tests

    • Simplified linear-model bias checks and added a ZBL finetuning test path.

iProzd added 2 commits July 28, 2025 14:53
Updated the logic for transferring state dict items to correctly handle keys related to ZBL models by replacing '.models.0.' with '.' and ensuring '.models.1.' items are retained. This improves compatibility when loading pretrained models with different model key structures.
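
A minimal sketch of the key handling this commit describes, assuming plain string state-dict keys; the helper name remap_pretrained_keys and its error text are illustrative, not the trainer's actual code.

```python
def remap_pretrained_keys(
    target_keys: list[str], pretrained_keys: set[str]
) -> dict[str, str]:
    """Map ZBL-model keys onto a standard pretrained checkpoint (illustrative).

    '.models.0.' entries (the DP sub-model) are looked up in the checkpoint
    under the un-nested name; '.models.1.' entries (the pair-table sub-model)
    have no pretrained counterpart and are kept for random initialization.
    """
    mapping: dict[str, str] = {}
    for key in target_keys:
        if ".models.1." in key:
            continue  # retained in the target model, initialized randomly
        source = key.replace(".models.0.", ".") if ".models.0." in key else key
        if source not in pretrained_keys:
            raise KeyError(f"no pretrained key {source!r} for target {key!r}")
        mapping[key] = source
    return mapping
```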
Introduces a compute_out_stat parameter to compute_or_load_stat methods in BaseAtomicModel, DPAtomicModel, LinearEnergyAtomicModel, and PairTabAtomicModel. This allows conditional computation of output statistics, improving flexibility and control over the statistics computation process.
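
A hedged sketch of the resulting guarded call (the flag was later named compute_or_load_out_stat in the merged code); the class and its underscored methods are simplified stand-ins for the atomic-model API, not the actual implementation.

```python
from typing import Callable, Optional, Union


class AtomicModelSketch:
    """Simplified stand-in for the atomic-model statistics interface."""

    def compute_or_load_stat(
        self,
        sampled_func: Union[Callable[[], list[dict]], list[dict]],
        stat_file_path: Optional[str] = None,
        compute_or_load_out_stat: bool = True,
    ) -> None:
        # Input statistics (e.g. descriptor mean/std) are always handled.
        self._compute_input_stats(sampled_func, stat_file_path)
        if compute_or_load_out_stat:
            # Skipped when a composite parent (e.g. the linear model)
            # owns the output bias and computes it once for the ensemble.
            self._compute_or_load_out_stat(sampled_func, stat_file_path)

    def _compute_input_stats(self, sampled_func, stat_file_path) -> None: ...

    def _compute_or_load_out_stat(self, sampled_func, stat_file_path) -> None: ...
```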
@iProzd marked this pull request as draft July 28, 2025 08:01
@iProzd requested a review from anyangml July 28, 2025 08:01

coderabbitai bot commented Jul 28, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Added an optional boolean parameter compute_or_load_out_stat (default True) to compute_or_load_stat across atomic models to allow skipping output-stat computation. Linear model output-stat handling was refactored (per-submodel output-stat removed, sampler wrapped to inject exclusions, apply_out_stat adjusted). Trainer finetune key-mapping was hardened for ZBL cases; tests updated.

Changes

Cohort / File(s) Change Summary
BaseAtomicModel: Signature Update
deepmd/pt/model/atomic_model/base_atomic_model.py
Added parameter compute_or_load_out_stat: bool = True to compute_or_load_stat; docstring updated to document input vs output stat handling; the method body still raises NotImplementedError.
DPAtomicModel: Conditional Output Stat
deepmd/pt/model/atomic_model/dp_atomic_model.py
compute_or_load_stat(..., compute_or_load_out_stat: bool = True) signature added; call to compute_or_load_out_stat is now conditional on the flag after input-stat computation.
PairTabAtomicModel: Conditional Output Stat
deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Renamed first param to sampled_func, added compute_or_load_out_stat: bool = True; docstring updated; the call to compute_or_load_out_stat is guarded by the flag and uses sampled_func.
LinearEnergyAtomicModel: Output Stat Refactor
deepmd/pt/model/atomic_model/linear_atomic_model.py
Removed the per-submodel compute_or_load_out_stat call. compute_or_load_stat(..., compute_or_load_out_stat: bool = True) now instructs sub-models to run compute_or_load_stat(..., compute_or_load_out_stat=False), builds a cached wrapped sampler that injects pair_exclude_types/atom_exclude_types, adjusts stat_file_path by type_map, and calls the output-stat routine on the linear model itself; apply_out_stat now adds the per-atom-type output bias (see the sampler sketch after this table).
Trainer Fine-tune Parameter Logic
deepmd/pt/train/training.py
Reworked collect_single_finetune_params mapping: compute new_key per candidate, determine use_random_initialization, handle ZBL naming patterns (.models.0., .models.1.), copy from random or origin accordingly, and raise clear errors for missing pretrained keys. Internal logic only.
Tests: Linear Model Bias & ZBL Finetune Paths
source/tests/pt/model/test_linear_atomic_model_stat.py, source/tests/pt/test_training.py
Simplified linear-model bias test to assert full-model forward output equals unbiased output plus per-atom-type bias. Added test_zbl_from_standard: bool on DPTrainTest, wired ZBL-from-standard finetune path and comparisons of mapped state_dict keys, and enabled ZBL finetune branches in relevant test setups.
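
A minimal sketch of the cached wrapped sampler described in the LinearEnergyAtomicModel row above, assuming the exclusion lists come from the linear model's masks; make_wrapped_sampler is an illustrative name, not the actual implementation.

```python
import functools


def make_wrapped_sampler(sampled_func, pair_exclude_types, atom_exclude_types):
    """Wrap a lazy sampler so every sampled frame carries the exclusion types."""

    @functools.lru_cache(maxsize=1)
    def wrapped_sampler():
        # Run the (possibly expensive) sampling once and reuse the result.
        sampled = sampled_func() if callable(sampled_func) else sampled_func
        for frame in sampled:
            frame["pair_exclude_types"] = pair_exclude_types
            frame["atom_exclude_types"] = atom_exclude_types
        return sampled

    return wrapped_sampler
```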

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant LinearModel
    participant SubModel
    participant Sampler

    Caller->>LinearModel: compute_or_load_stat(sampled_func, stat_file_path, compute_or_load_out_stat)
    LinearModel->>SubModel: compute_or_load_stat(sampled_func, stat_file_path, compute_or_load_out_stat=False)
    Sampler->>LinearModel: samples (wrapped to inject pair/atom exclusion types)
    alt compute_or_load_out_stat == True
        LinearModel->>LinearModel: compute_or_load_out_stat(wrapped_sampler, stat_file_path)
        LinearModel->>LinearModel: apply_out_stat(per-atom-type biases)
    else
        Note over LinearModel: Skip computing/loading output stats
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~30–45 minutes

Suggested reviewers

  • njzjz
  • wanghan-iapcm

anyangml and others added 4 commits July 28, 2025 10:02
Enhanced Trainer to support fine-tuning ZBL models from standard models by handling key mapping and random state initialization. Added corresponding tests to verify ZBL fine-tuning behavior and ensure correct state dict transfer in test_training.py.
@caic99 requested a review from Copilot August 12, 2025 09:33

Copilot AI left a comment


Pull Request Overview

This PR adds support for ZBL (Ziegler-Biersack-Littmark) model fine-tuning from standard models, allowing atomic models to transition between different model types during training.

  • Enhanced parameter handling during fine-tuning to support ZBL model initialization from standard models
  • Modified atomic model statistics computation to provide optional control over output statistics loading
  • Updated linear atomic model bias application logic to work directly with the model output

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
source/tests/pt/test_training.py Added test configuration and logic for ZBL fine-tuning from standard models
source/tests/pt/model/test_linear_atomic_model_stat.py Simplified bias application verification in linear atomic model tests
deepmd/pt/train/training.py Enhanced fine-tuning parameter collection to handle ZBL model branches and key mapping
deepmd/pt/model/atomic_model/pairtab_atomic_model.py Added optional compute_out_stat parameter to control statistics computation
deepmd/pt/model/atomic_model/linear_atomic_model.py Refactored bias application and statistics handling for linear atomic models (per-type bias indexing sketched after this table)
deepmd/pt/model/atomic_model/dp_atomic_model.py Added conditional control for output statistics computation
deepmd/pt/model/atomic_model/base_atomic_model.py Added compute_out_stat parameter to base interface
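
For the bias-application rows above, a small sketch of how a per-type bias broadcasts onto per-atom output via advanced indexing; all shapes and tensors here are made up for illustration.

```python
import torch

nf, nloc, ntypes, odims = 2, 5, 3, 1
out_bias = torch.randn(ntypes, odims)         # per-type bias: ntypes x odims
atype = torch.randint(0, ntypes, (nf, nloc))  # atom types: nf x nloc
unbiased = torch.randn(nf, nloc, odims)       # raw per-atom model output

# out_bias[atype] has shape nf x nloc x odims, so each atom
# receives the bias of its own type.
biased = unbiased + out_bias[atype]
assert biased.shape == (nf, nloc, odims)
```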


codecov bot commented Aug 12, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.29%. Comparing base (cefce47) to head (404b915).
⚠️ Report is 70 commits behind head on devel.

Files with missing lines Patch % Lines
...eepmd/pt/model/atomic_model/linear_atomic_model.py 65.00% 7 Missing ⚠️
...epmd/pt/model/atomic_model/pairtab_atomic_model.py 50.00% 1 Missing ⚠️
deepmd/pt/train/training.py 88.88% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4849      +/-   ##
==========================================
- Coverage   84.30%   84.29%   -0.02%     
==========================================
  Files         702      702              
  Lines       68620    68647      +27     
  Branches     3572     3573       +1     
==========================================
+ Hits        57850    57864      +14     
- Misses       9630     9641      +11     
- Partials     1140     1142       +2     


@iProzd requested reviews from njzjz and wanghan-iapcm August 12, 2025 13:50
@iProzd marked this pull request as ready for review August 12, 2025 13:51

@coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
deepmd/pt/train/training.py (1)

513-532: Consider improving variable naming and error messages per reviewer feedback

The current implementation has two suggestions from previous reviews that remain valid:

  1. The variable name use_random_initialization is misleading since it controls whether to use random initialization OR pretrained weights
  2. The error message could be more descriptive to help users troubleshoot
🧹 Nitpick comments (1)
source/tests/pt/test_training.py (1)

255-266: Consider extracting common ZBL test setup to reduce duplication

The ZBL configuration setup is duplicated between test classes. Consider extracting this into a helper method or base class to improve maintainability.

+    def _setup_zbl_config(self):
+        """Helper method to setup ZBL configuration for testing."""
+        input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
+        with open(input_json_zbl) as f:
+            self.config_zbl = json.load(f)
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.config_zbl["training"]["training_data"]["systems"] = data_file
+        self.config_zbl["training"]["validation_data"]["systems"] = data_file
+        self.config_zbl["model"] = deepcopy(model_zbl)
+        self.config_zbl["training"]["numb_steps"] = 1
+        self.config_zbl["training"]["save_freq"] = 1
+
     def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
         with open(input_json) as f:
             self.config = json.load(f)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
         self.config["model"] = deepcopy(model_dpa1)
         self.config["training"]["numb_steps"] = 1
         self.config["training"]["save_freq"] = 1

         self.test_zbl_from_standard = True
-
-        input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
-        with open(input_json_zbl) as f:
-            self.config_zbl = json.load(f)
-        data_file = [str(Path(__file__).parent / "water/data/data_0")]
-        self.config_zbl["training"]["training_data"]["systems"] = data_file
-        self.config_zbl["training"]["validation_data"]["systems"] = data_file
-        self.config_zbl["model"] = deepcopy(model_zbl)
-        self.config_zbl["training"]["numb_steps"] = 1
-        self.config_zbl["training"]["save_freq"] = 1
+        self._setup_zbl_config()
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2fdf1c9 and 7770428.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py (1 hunks)
  • source/tests/pt/test_training.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
source/tests/pt/test_training.py (5)
deepmd/pt/utils/finetune.py (1)
  • get_finetune_rules (82-203)
deepmd/pt/entrypoints/main.py (1)
  • get_trainer (96-201)
deepmd/dpmodel/output_def.py (2)
  • wrapper (45-72)
  • wrapper (87-105)
deepmd/pd/train/wrapper.py (1)
  • state_dict (199-203)
deepmd/pt/train/training.py (1)
  • run (680-1170)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (2)
source/tests/pt/test_training.py (2)

33-34: LGTM! Well-structured test attribute for ZBL finetuning

The addition of test_zbl_from_standard as a class attribute is a clean way to enable ZBL finetuning tests selectively.


100-127: Test logic correctly verifies ZBL finetuning key mapping

The test implementation properly validates the key remapping logic for ZBL models:

  • Maps ".models.0." keys back to their original form for comparison
  • Skips ".models.1." keys as they should be randomly initialized
  • Excludes bias/std keys from comparison as expected

The test provides good coverage of the finetuning logic introduced in the training module.
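
A sketch of the comparison loop the review describes, assuming state dicts of plain tensors; check_mapped_state_dict is an illustrative name, not the test's actual helper.

```python
import torch


def check_mapped_state_dict(finetuned: dict, pretrained: dict) -> None:
    for key, tensor in finetuned.items():
        if "bias" in key or "std" in key:
            continue  # statistics are recomputed rather than transferred
        if ".models.1." in key:
            continue  # pair-table branch is randomly initialized
        source_key = key.replace(".models.0.", ".")
        assert torch.allclose(tensor, pretrained[source_key]), key
```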

@iProzd requested a review from njzjz August 13, 2025 05:26
iProzd added 2 commits August 13, 2025 13:26
Revised and clarified the docstrings for compute_or_load_stat in both BaseAtomicModel and PairTabAtomicModel to better describe the function parameters and behavior. Updated parameter names and descriptions for improved consistency and readability.

@coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

284-289: Propagate compute_or_load_out_stat through PT & PD wrappers

The public compute_or_load_stat in both make_model.py files still accepts only (sampled_func, stat_file_path) and drops the new compute_or_load_out_stat flag. Update them to match your atomic-model API:

• deepmd/pt/model/model/make_model.py
• deepmd/pd/model/model/make_model.py

Suggested diff (apply to both files at the compute_or_load_stat definition):

-        def compute_or_load_stat(
-            self,
-            sampled_func,
-            stat_file_path: Optional[DPPath] = None,
-        ):
-            """Compute or load the statistics."""
-            return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
+        def compute_or_load_stat(
+            self,
+            sampled_func,
+            stat_file_path: Optional[DPPath] = None,
+            compute_or_load_out_stat: bool = True,
+        ):
+            """Compute or load the statistics. Pass through `compute_or_load_out_stat`."""
+            return self.atomic_model.compute_or_load_stat(
+                sampled_func, stat_file_path, compute_or_load_out_stat
+            )

This ensures external callers can disable output‐stat computation via the public API.

♻️ Duplicate comments (2)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

369-386: Docstring update looks good and addresses prior feedback

The description now clearly differentiates input vs output statistics and documents the new flag. This resolves the prior request to update the docs.

deepmd/pt/model/atomic_model/linear_atomic_model.py (1)

2-2: Nit: import ordering (duplicate of prior nitpick)

There was a previous comment about the position of the functools import relative to typing. If you’re following a specific internal import order guideline, please align with it; otherwise, current ordering is acceptable.

🧹 Nitpick comments (7)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

362-367: Return type should be None (not NoReturn) to match subclass implementations

Subclasses (e.g., DPAtomicModel, LinearEnergyAtomicModel) return None. Annotating the base as NoReturn (a “never-returns” type) is incompatible with overriding methods that do return normally and may cause type-checker issues.

Apply this diff:

 def compute_or_load_stat(
     self,
     merged: Union[Callable[[], list[dict]], list[dict]],
     stat_file_path: Optional[DPPath] = None,
-    compute_or_load_out_stat: bool = True,
-) -> NoReturn:
+    compute_or_load_out_stat: bool = True,
+) -> None:

Note: If NoReturn becomes unused after this change, we can remove it from the typing imports in a follow-up.

deepmd/pt/model/atomic_model/dp_atomic_model.py (3)

291-297: Docstring nit: refer to sampled_func consistently

Use the same parameter name as the function signature to avoid confusion.

-When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+When `sampled_func` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+When `sampled_func` is not provided, it will check the existence of `stat_file_path`(s)

302-304: Docstring nit: stat_file_path type is a path, not a dict

The doc currently says “The dictionary of paths…”. The type is Optional[DPPath].

-        stat_file_path
-            The dictionary of paths to the statistics files.
+        stat_file_path
+            The path to the statistics file(s).

313-325: Cache the sampler with an explicit bound

wrapped_sampler takes no arguments, so the cache will contain a single entry in practice. Still, explicitly setting maxsize=1 makes the intent clear and prevents accidental unbounded growth if arguments are added later.

-        @functools.lru_cache
+        @functools.lru_cache(maxsize=1)
         def wrapped_sampler():
deepmd/pt/model/atomic_model/linear_atomic_model.py (3)

321-326: Avoid duplication: delegate to base implementation of apply_out_stat

This override now matches BaseAtomicModel.apply_out_stat verbatim. Prefer delegating to reduce maintenance overhead and divergence risk.

     def apply_out_stat(
         self,
         ret: dict[str, torch.Tensor],
         atype: torch.Tensor,
     ):
         """Apply the stat to each atomic output.
@@
-        out_bias, out_std = self._fetch_out_stat(self.bias_keys)
-        for kk in self.bias_keys:
-            # nf x nloc x odims, out_bias: ntypes x odims
-            ret[kk] = ret[kk] + out_bias[kk][atype]
-        return ret
+        return super().apply_out_stat(ret, atype)

471-476: Docstring nit: refer to sampled_func consistently

Mirror the function signature to avoid confusion.

-When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+When `sampled_func` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+When `sampled_func` is not provided, it will check the existence of `stat_file_path`(s)

489-494: Docstring nit: stat_file_path type is a path, not a dict

Align the description with the actual type: Optional[DPPath].

-        stat_file_path
-            The dictionary of paths to the statistics files.
+        stat_file_path
+            The path to the statistics file(s).
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7770428 and 2ef9444.

📒 Files selected for processing (4)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (2 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (4 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
deepmd/pt/model/atomic_model/linear_atomic_model.py (5)
deepmd/pt/model/atomic_model/base_atomic_model.py (3)
  • _fetch_out_stat (576-592)
  • compute_or_load_out_stat (389-414)
  • compute_or_load_stat (362-387)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
  • compute_or_load_stat (284-331)
  • wrapped_sampler (314-324)
deepmd/pd/model/model/make_model.py (1)
  • compute_or_load_stat (568-574)
deepmd/pt/model/model/make_model.py (1)
  • compute_or_load_stat (572-578)
deepmd/pt/utils/exclude_mask.py (2)
  • get_exclude_types (35-36)
  • get_exclude_types (101-102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (26)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

284-307: Conditional computation of output stats is correct

The new compute_or_load_out_stat flag is properly surfaced and documented here, and the computation is guarded accordingly. This aligns with the PR objective.

deepmd/pt/model/atomic_model/linear_atomic_model.py (1)

495-499: Good: sub-models skip output-stat computation

Passing compute_or_load_out_stat=False to sub-models avoids double-biasing. This aligns with the new design.


@coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

362-367: Harmonize compute_or_load_stat Signatures and Return Types

All overrides of compute_or_load_stat must match the base signature and return annotation. Change the base method’s return type from NoReturn to None and introduce the new compute_or_load_out_stat: bool = True parameter in every subclass override. No call-site changes are needed since the new parameter has a default.

Files requiring updates (add the third parameter and change -> NoReturn to -> None):

  • deepmd/pt/model/atomic_model/base_atomic_model.py (lines 362–367)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (lines 225–227)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (lines 471–473)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (lines 284–286)
  • deepmd/pt/model/model/model.py (lines 29–31)
  • deepmd/pt/model/model/spin_model.py (lines 344–346)
  • deepmd/pt/model/model/make_model.py (lines 572–574)
  • deepmd/pd/model/atomic_model/base_atomic_model.py (lines 365–367)
  • deepmd/pd/model/atomic_model/dp_atomic_model.py (lines 329–331)
  • deepmd/pd/model/model/model.py (lines 23–25)
  • deepmd/pd/model/model/make_model.py (lines 568–570)

Apply this snippet to each definition:

-    def compute_or_load_stat(
-        self,
-        …,
-        stat_file_path: Optional[DPPath] = None,
-) -> NoReturn:
+    def compute_or_load_stat(
+        self,
+        …,
+        stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
+) -> None:

Optionally, mark the base method as @abstractmethod (imported from abc) for clarity.

♻️ Duplicate comments (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

369-385: Docstring mismatches: fix parameter names and path description

  • Use “merged” consistently instead of “sampled”.
  • stat_file_path is a single DPPath, not “dictionary of paths”.

Apply this doc-only diff:

-Compute or load the statistics parameters of the model,
-such as mean and standard deviation of descriptors or the energy bias of the fitting net.
-When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-and saved in the `stat_file_path`(s).
-When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
-and load the calculated statistics parameters.
+Compute or load the statistics parameters of the model,
+such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+When `merged` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+and saved to the given `stat_file_path`.
+When `merged` is not provided, it will check the existence of `stat_file_path`
+and load the calculated statistics parameters.
@@
-merged
-    The lazy sampled function to get data frames from different data systems.
+merged
+    The lazy sampled function to get data frames from different data systems.
@@
-stat_file_path
-    The dictionary of paths to the statistics files.
+stat_file_path
+    The path to the statistics file.
@@
-compute_or_load_out_stat : bool
+compute_or_load_out_stat : bool
     Whether to compute the output statistics.
     If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)

232-248: Docstring: align names and path description with the signature and types

  • Replace “sampled” with “merged” to match the parameter.
  • stat_file_path is a single DPPath, not a “dictionary of paths”.

Apply this doc-only diff:

-Compute or load the statistics parameters of the model,
-such as mean and standard deviation of descriptors or the energy bias of the fitting net.
-When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-and saved in the `stat_file_path`(s).
-When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
-and load the calculated statistics parameters.
+Compute or load the statistics parameters of the model,
+such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+When `merged` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+and saved to the given `stat_file_path`.
+When `merged` is not provided, it will check the existence of `stat_file_path`
+and load the calculated statistics parameters.
@@
-sampled_func
-    The lazy sampled function to get data frames from different data systems.
+merged
+    The lazy sampled function to get data frames from different data systems.
@@
-stat_file_path
-    The dictionary of paths to the statistics files.
+stat_file_path
+    The path to the statistics file.
@@
-compute_or_load_out_stat : bool
+compute_or_load_out_stat : bool
     Whether to compute the output statistics.
     If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2ef9444 and 404b915.

📒 Files selected for processing (2)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (2)
deepmd/utils/path.py (1)
  • DPPath (27-157)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • compute_or_load_out_stat (389-414)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (20)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)

250-251: LGTM: guarded invocation of output-stat computation

Conditionally computing output stats respects the new flag and is appropriate for PairTab (no input stats to handle here).

@iProzd added this pull request to the merge queue Aug 13, 2025
Merged via the queue into deepmodeling:devel with commit 1c29fe4 Aug 13, 2025
60 checks passed