
Conversation

njzjz (Member) commented Oct 25, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for the TensorFlow SavedModel format, allowing users to handle additional model file types.
    • Introduced a new TensorFlow model wrapper class for enhanced integration with JAX functionalities.
  • Bug Fixes

    • Improved error handling for unsupported file formats during model deserialization.
  • Documentation

    • Updated backend documentation to reflect new file extensions and clarify backend capabilities.
  • Tests

    • Enhanced test structure for better clarity and maintainability regarding backend handling.
    • Added a new job for testing TensorFlow 2 in eager mode within the testing workflow.
    • Introduced a conditional skip for tests based on TensorFlow 2 compatibility.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
coderabbitai bot (Contributor) commented Oct 25, 2024

## Walkthrough
The pull request introduces several updates across multiple files, primarily enhancing the JAX backend's capabilities. Key changes include the addition of the ".savedmodel" suffix to the `JAXBackend` class, updates to the `deserialize_to_file` function to support TensorFlow's SavedModel format, and modifications to the `DeepEval` class for improved model loading based on file extensions. New functionalities are added in the `jax2tf` module, including a TensorFlow model wrapper class. Documentation is also updated to reflect these changes, particularly regarding supported file formats.
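As background, the jax2tf export path described above reduces to something like the following sketch (the `model_fn` callable and the input signature are illustrative placeholders, not the PR's actual code):

import tensorflow as tf
from jax.experimental import jax2tf

def export_savedmodel(model_fn, model_file: str) -> None:
    """Convert a JAX callable to TF and save it as a SavedModel directory."""
    module = tf.Module()
    # jax2tf.convert produces a TF-compatible function; with_gradient=True
    # preserves differentiability through custom gradients.
    module.f = tf.function(
        jax2tf.convert(model_fn, with_gradient=True),
        autograph=False,
        input_signature=[tf.TensorSpec([None, None, 3], tf.float64)],
    )
    tf.saved_model.save(
        module,
        model_file,
        options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
    )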

## Changes

| File Path                                    | Change Summary                                                                                     |
|----------------------------------------------|---------------------------------------------------------------------------------------------------|
| `deepmd/backend/jax.py`                     | Updated `suffixes` in `JAXBackend` class to include ".savedmodel".                               |
| `deepmd/jax/utils/serialization.py`         | Enhanced `deserialize_to_file` to handle ".savedmodel" files; generalized error handling.        |
| `deepmd/jax/infer/deep_eval.py`            | Added conditional logic to load models based on file extension; improved error handling.          |
| `deepmd/jax/jax2tf/__init__.py`            | Introduced a check for TensorFlow eager execution; raises `RuntimeError` if not in eager mode.   |
| `deepmd/jax/jax2tf/serialization.py`        | Added `deserialize_to_file` function for TensorFlow model files.                                 |
| `deepmd/jax/jax2tf/tfmodel.py`             | Introduced `TFModelWrapper` class for JAX and TensorFlow integration; added multiple methods.     |
| `doc/backend.md`                            | Updated documentation to include ".savedmodel" as a supported file extension for JAX backend.    |
| `source/tests/consistent/io/test_io.py`    | Modified `test_deep_eval` to improve backend handling and clarity; no changes to functionality.   |
| `.github/workflows/test_python.yml`         | Added new job "Test TF2 eager mode" for targeted testing; refined caching strategy.               |
| `source/tests/utils.py`                     | Introduced new variable `DP_TEST_TF2_ONLY` for environment variable checks.                       |

## Possibly related PRs
- **#3776**: The changes in the `.github/workflows/test_python.yml` file involve adjustments to the Python environment setup and package installations, which may relate to the overall integration of the JAX backend.
- **#4156**: This PR introduces a JAX backend implementation, which is directly related to the changes made in the main PR regarding the `JAXBackend` class.
- **#4236**: The modifications to the `JAXBackend` class and the introduction of serialization and deserialization functionalities are closely related to the changes in the main PR.
- **#4251**: The introduction of new methods in the `make_model` function that involve atomic calculations aligns with the changes made in the main PR regarding the handling of model outputs.
- **#4259**: The documentation updates for the JAX backend are relevant as they provide context and support for the changes made in the main PR.
- **#4278**: The enhancements to the `DipoleFitting` and `PolarFitting` classes to support array API compatibility are related to the main PR's focus on JAX integration.
- **#4294**: The modifications to the `DescrptDPA2` class to improve compatibility with array operations are directly relevant to the changes made in the main PR.

## Suggested labels
`breaking change`

## Suggested reviewers
- wanghan-iapcm


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
deepmd/backend/jax.py (1)

Line range hint 27-42: Document supported formats in class docstring.

Consider enhancing the class documentation to explicitly mention the supported file formats. This would help users understand which formats they can use with the JAX backend.

 class JAXBackend(Backend):
-    """JAX backend."""
+    """JAX backend.
+    
+    Supports the following model formats:
+    - .jax: Native JAX format
+    - .savedmodel: TensorFlow SavedModel format
+    """
deepmd/jax/utils/serialization.py (2)

58-64: Clarify the usage of polymorphic_shapes in jax2tf.convert

The polymorphic_shapes parameter uses variables like nf, nloc, and nghost. If these are placeholders for dynamic dimensions, consider providing documentation or comments explaining their meanings. This will enhance code readability and maintainability for future developers.
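For illustration, such a comment block might look like this (the shape strings are assumptions mirroring the variable names above, not the PR's exact code):

from jax.experimental import jax2tf

# nf: number of frames; nloc: number of local atoms; nghost: number of
# ghost atoms. Symbolic dimension names let jax2tf trace the function once
# for arbitrary batch and atom counts.
tf_call_lower = jax2tf.convert(
    call_lower,
    polymorphic_shapes=[
        "(nf, nloc + nghost, 3)",  # extended coordinates
        "(nf, nloc + nghost)",     # extended atom types
        "(nf, nloc, nnei)",        # neighbor list
    ],
)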


Line range hint 88-89: Update serialize_from_file to handle .savedmodel files

Currently, serialize_from_file only supports .jax files and raises a ValueError for other formats. Since deserialize_to_file now handles .savedmodel files, consider adding support for serializing from .savedmodel files for consistency. Alternatively, update the error message to reflect the supported formats.

Example of updating the error message:

 else:
-    raise ValueError("JAX backend only supports converting .jax directory")
+    raise ValueError("JAX backend only supports converting .jax and .savedmodel formats")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 02580c2 and d46d5f0.

📒 Files selected for processing (3)
  • deepmd/backend/jax.py (1 hunks)
  • deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
  • deepmd/jax/utils/serialization.py (1 hunks)
🔇 Additional comments (6)
deepmd/backend/jax.py (1)

41-41: Verify SavedModel serialization implementation.

The addition of ".savedmodel" suffix aligns with the PR objective. Let's verify that the serialization hooks properly support this format.

✅ Verification successful

SavedModel serialization implementation verified

The JAX backend properly implements SavedModel format support:

  1. The backend correctly registers .savedmodel suffix in JAXBackend class
  2. The serialization implementation in deepmd/jax/utils/serialization.py includes:
    • Converting JAX model to TensorFlow using jax2tf.convert
    • Saving as TensorFlow SavedModel format using tf.saved_model.save
    • Proper shape specifications for the converted model
    • Support for custom gradients
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that serialization hooks handle SavedModel format

# Test: Check if deserialize_to_file has SavedModel handling
rg -A 10 "deserialize_to_file.*savedmodel" deepmd/jax/utils/serialization.py

# Test: Check if serialize_from_file has SavedModel handling
rg -A 10 "serialize_from_file.*savedmodel" deepmd/jax/utils/serialization.py

Length of output: 157


Script:

#!/bin/bash
# Let's check for any SavedModel related code in serialization files
rg -l "SavedModel|savedmodel" --type py

# Check implementation details in serialization-related files
rg -A 10 "serialize|deserialize" --type py

# Look for any JAX to SavedModel conversion functions
ast-grep --pattern 'def $_($$) {
  $$$
  savedmodel
  $$$
}'

Length of output: 456711


Script:

#!/bin/bash
# Let's check for SavedModel related code in JAX backend
rg -A 10 "SavedModel|savedmodel" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Check if there are any JAX to TF conversion functions
rg -A 10 "jax2tf" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Look for any model export/save functions
rg -A 10 "save.*model|export.*model" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

Length of output: 5000

deepmd/dpmodel/descriptor/se_e2_a.py (2)

558-558: LGTM: Simplified array assignment

The removal of unnecessary xp.asarray conversion is appropriate since self.sel_cumsum is already suitable for indexing operations.


Line range hint 580-599: Performance optimization: Replaced einsum with broadcasting

The replacement of einsum operations with equivalent sum and broadcasting operations is a good optimization. The changes maintain mathematical equivalence while potentially offering better performance on some hardware accelerators.

Original:

gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
grrg = xp.einsum("flid,fljd->flij", gr, gr1)

New:

gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)

Let's verify the numerical equivalence of these operations:

✅ Verification successful

Optimization change is mathematically equivalent and safe to use

The replacement of einsum with broadcasting and sum operations is mathematically sound for the following reasons:

  1. The original einsum operation xp.einsum("lni,lnj->lij", gg, tr) is equivalent to the new broadcasting operation xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) as they both perform the same matrix multiplication along the specified axes.

  2. Similarly, xp.einsum("flid,fljd->flij", gr, gr1) is equivalent to xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4), maintaining the same contraction pattern.

  3. The codebase shows consistent use of both einsum and broadcasting+sum patterns across different implementations (TensorFlow, PyTorch, and NumPy), indicating this is a well-understood transformation.
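A quick standalone NumPy check of this equivalence (shapes chosen arbitrarily):

import numpy as np

rng = np.random.default_rng(0)
gg = rng.standard_normal((2, 5, 3))      # (l, n, i)
tr = rng.standard_normal((2, 5, 4))      # (l, n, j)
gr = rng.standard_normal((2, 3, 4, 6))   # (f, l, i, d)
gr1 = rng.standard_normal((2, 3, 5, 6))  # (f, l, j, d)

# First contraction: einsum vs. broadcasting + sum
assert np.allclose(
    np.einsum("lni,lnj->lij", gg, tr),
    np.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1),
)
# Second contraction
assert np.allclose(
    np.einsum("flid,fljd->flij", gr, gr1),
    np.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4),
)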

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any tests that verify the descriptor's output
rg -l "test.*descriptor.*se.*a" --type py

Length of output: 85


Script:

#!/bin/bash
# Let's examine the test file content to see if it verifies numerical equivalence
cat source/tests/tf/test_descrpt_se_a_mask.py

Length of output: 8099


Script:

#!/bin/bash
# Let's look for any other tests that might verify the descriptor output
rg -l "test.*descriptor.*se.*[aA]" --type py

# Also search for any numerical tests specifically related to einsum operations
rg "einsum|sum.*broadcast" --type py

Length of output: 3078

deepmd/jax/utils/serialization.py (3)

50-51: Verify that call_lower is a valid attribute of model

The code assigns call_lower = model.call_lower. Please ensure that all models being deserialized have the call_lower attribute. If some models might not have this attribute, consider adding a check or handling potential AttributeError exceptions to prevent runtime errors.
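A defensive check could look like this sketch:

call_lower = getattr(model, "call_lower", None)
if call_lower is None:
    raise AttributeError(
        f"{type(model).__name__} does not define `call_lower`; "
        "only models exposing it can be exported to the SavedModel format."
    )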


68-73: Ensure correct tensor shapes in tf.TensorSpec

The tf.TensorSpec definitions include dynamic dimensions (None) and a call to model.get_nnei(). Verify that model.get_nnei() returns an integer and that the tensor shapes align with the expected input dimensions. Misalignment can lead to runtime errors when the SavedModel is used.
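For reference, the specs take roughly this form (the dtypes and axis meanings here are assumptions, not the PR's exact code):

import tensorflow as tf

nnei = model.get_nnei()  # must be a plain Python int at trace time
input_signature = [
    tf.TensorSpec([None, None, 3], tf.float64),   # extended coordinates (nf, nall, 3)
    tf.TensorSpec([None, None], tf.int32),        # extended atom types (nf, nall)
    tf.TensorSpec([None, None, nnei], tf.int64),  # neighbor list (nf, nloc, nnei)
]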


76-79: Review the necessity of experimental_custom_gradients=True

The option experimental_custom_gradients=True is used in tf.saved_model.SaveOptions. Confirm that custom gradients are required for your use case. If not, removing this option could simplify the code and avoid potential compatibility issues with future TensorFlow versions.

codecov bot commented Oct 25, 2024

Codecov Report

Attention: Patch coverage is 82.20859% with 29 lines in your changes missing coverage. Please review.

Project coverage is 84.33%. Comparing base (bfbe2ed) to head (9a94118).
Report is 137 commits behind head on devel.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/jax/jax2tf/serialization.py | 79.41% | 14 Missing ⚠️ |
| deepmd/jax/jax2tf/tfmodel.py | 85.00% | 12 Missing ⚠️ |
| deepmd/jax/infer/deep_eval.py | 85.71% | 1 Missing ⚠️ |
| deepmd/jax/jax2tf/__init__.py | 66.66% | 1 Missing ⚠️ |
| deepmd/jax/utils/serialization.py | 75.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4254      +/-   ##
==========================================
- Coverage   84.38%   84.33%   -0.06%     
==========================================
  Files         563      570       +7     
  Lines       52810    53070     +260     
  Branches     3054     3054              
==========================================
+ Hits        44564    44756     +192     
- Misses       7287     7358      +71     
+ Partials      959      956       -3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

njzjz added 7 commits November 2, 2024 20:52, each signed off by Jinzhe Zeng <jinzhe.zeng@rutgers.edu>.
github-actions bot added the Docs label Nov 3, 2024
njzjz marked this pull request as ready for review November 3, 2024 09:44
njzjz requested a review from wanghan-iapcm November 3, 2024 09:44
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (15)
deepmd/jax/jax2tf/__init__.py (2)

4-11: Fix typo and enhance error message clarity.

The runtime check for eager execution is crucial and well-placed. However, the error message could be improved.

 if not tf.executing_eagerly():
     # TF disallow temporary eager execution
     raise RuntimeError(
-        "Unfortunatly, jax2tf (requires eager execution) cannot be used with the "
+        "Unfortunately, jax2tf cannot be used with the TensorFlow backend. This is because "
+        "jax2tf requires TensorFlow to run in eager execution mode, which is disabled by "
         "TensorFlow backend (disables eager execution). "
         "If you are converting a model between different backends, "
-        "considering converting to the `.dp` format first."
+        "consider converting to the `.dp` format first."
     )

5-5: Improve comment clarity.

The current comment is ambiguous and could be more descriptive.

-    # TF disallow temporary eager execution
+    # The TensorFlow backend explicitly disables eager execution mode
doc/backend.md (1)

33-33: Consider enhancing the SavedModel format documentation.

While the current documentation explains the basic requirement, it would be beneficial to add:

  1. A brief example or link to example usage
  2. Any limitations or considerations when using the SavedModel format

Example addition:

 `.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
+For usage examples and best practices, refer to the [JAX2TF conversion guide](https://www.tensorflow.org/guide/jax2tf).
deepmd/jax/infer/deep_eval.py (3)

103-108: Consider adding SavedModel validation and error handling.

While the implementation is functional, it could benefit from additional error handling to gracefully handle invalid SavedModel files or import failures.

Consider adding try-catch blocks:

 elif model_file.endswith(".savedmodel"):
     from deepmd.jax.jax2tf.tfmodel import (
         TFModelWrapper,
     )
-
-    self.dp = TFModelWrapper(model_file)
+    try:
+        self.dp = TFModelWrapper(model_file)
+    except Exception as e:
+        raise ValueError(f"Failed to load SavedModel from {model_file}: {str(e)}")

109-110: Enhance error message with supported file extensions.

The current error message could be more helpful by explicitly listing the supported file extensions.

-            raise ValueError("Unsupported file extension")
+            raise ValueError(
+                f"Unsupported file extension for {model_file}. "
+                "Supported extensions are: .hlo, .savedmodel"
+            )

93-110: Consider implementing a model loader factory pattern.

The current implementation with multiple conditional blocks could be refactored to use a factory pattern, making it more maintainable and extensible for future model formats.

This would involve:

  1. Creating a ModelLoader interface/abstract class
  2. Implementing specific loaders for each format (HLOModelLoader, SavedModelLoader)
  3. Using a factory to create the appropriate loader based on file extension

Example structure:

from abc import ABC, abstractmethod

class ModelLoader(ABC):
    @abstractmethod
    def load(self, model_file: str):
        pass

class HLOModelLoader(ModelLoader):
    def load(self, model_file: str):
        model_data = load_dp_model(model_file)
        return HLO(
            stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
            stablehlo_atomic_virial=model_data["@variables"]["stablehlo_atomic_virial"].tobytes(),
            model_def_script=model_data["model_def_script"],
            **model_data["constants"],
        )

class SavedModelLoader(ModelLoader):
    def load(self, model_file: str):
        return TFModelWrapper(model_file)

class ModelLoaderFactory:
    @staticmethod
    def create_loader(model_file: str) -> ModelLoader:
        if model_file.endswith(".hlo"):
            return HLOModelLoader()
        elif model_file.endswith(".savedmodel"):
            return SavedModelLoader()
        raise ValueError(f"Unsupported file extension for {model_file}")

This would make the code more maintainable and easier to extend with new model formats in the future.

deepmd/jax/jax2tf/serialization.py (3)

75-88: Use explicit data types for clarity and consistency

In the creation of tf.Variable instances, using explicit data types like tf.float64 instead of tf.double enhances readability and follows TensorFlow conventions.

Update the data types as follows:

 tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.float64)
 ...
 tf_model.min_nbor_dist = tf.Variable(
     model.get_min_nbor_dist(), dtype=tf.float64
 )

90-92: Optimize storage of JSON data in tf.Variable

Storing large JSON strings in a tf.Variable may not be efficient, especially if the data does not change during model execution. Consider using a tf.constant or attaching the JSON data as an attribute instead.

You could modify the code as follows:

-tf_model.model_def_script = tf.Variable(
+tf_model.model_def_script = tf.constant(
     json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
 )

93-97: Handle exceptions during model saving

When saving the model using tf.saved_model.save, exceptions can occur due to issues like file permission errors or invalid model definitions. Wrapping the save operation in a try-except block can provide clearer error messages.

Implement exception handling:

 try:
     tf.saved_model.save(
         tf_model,
         model_file,
         options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
     )
+except Exception as e:
+    raise RuntimeError(f"Failed to save the model to '{model_file}': {e}")
deepmd/jax/jax2tf/tfmodel.py (6)

178-178: Remove or clarify the commented-out error message

The comment on line 178 appears to be an error message from a previous run:

# Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

This could cause confusion for readers. If it's not needed, consider removing it. If it serves as a reminder or note, provide additional context.


274-277: Ensure consistency between 'get_nsel' and 'get_nnei' methods

Both get_nsel and get_nnei methods seem to serve the same purpose—returning the number of selected neighboring atoms. Additionally, their docstrings are nearly identical. Consider consolidating these methods or clearly distinguishing their functionalities to avoid confusion.
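One possible consolidation (a sketch, assuming the value is stored as `self.nnei`):

def get_nnei(self) -> int:
    """Get the total number of selected neighboring atoms in the cut-off radius."""
    return self.nnei

# Keep the second name as an alias instead of duplicating body and docstring.
get_nsel = get_nnei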


75-109: Avoid duplicating docstrings between `__call__` and `call` methods

The __call__ and call methods have identical docstrings. This duplication can lead to maintenance issues and inconsistencies in the future. Consider having a single comprehensive docstring in one method and referencing it in the other.

For example, you can modify the __call__ method's docstring:

 def __call__(
     self,
     coord: jnp.ndarray,
     atype: jnp.ndarray,
     box: Optional[jnp.ndarray] = None,
     fparam: Optional[jnp.ndarray] = None,
     aparam: Optional[jnp.ndarray] = None,
     do_atomic_virial: bool = False,
 ) -> Any:
-    """Return model prediction.
-
-    Parameters
-    ----------
-    ... [docstring contents] ...
-    """
+    """Return model prediction by calling the `call` method."""
     return self.call(coord, atype, box, fparam, aparam, do_atomic_virial)

197-204: Add type hints to getter methods for clarity

The methods get_rcut, get_dim_fparam, and get_dim_aparam currently lack return type hints. Adding explicit return types enhances code readability and helps with static type checking.

Apply this diff to add return type hints:

 def get_rcut(self) -> float:
     """Get the cut-off radius."""
     return self.rcut

 def get_dim_fparam(self) -> int:
     """Get the number (dimension) of frame parameters of this atomic model."""
     return self.dim_fparam

 def get_dim_aparam(self) -> int:
     """Get the number (dimension) of atomic parameters of this atomic model."""
     return self.dim_aparam

243-258: Implement the 'serialize' and 'deserialize' methods or raise NotImplementedError with guidance

The serialize and deserialize methods currently raise NotImplementedError without additional context. If these methods are intended for future implementation, consider adding a comment or docstring explaining when they will be implemented or what is expected.
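For instance, a more informative stub might read (the message text is illustrative):

def serialize(self) -> dict:
    """Serialize the model.

    Not supported for SavedModel wrappers: the underlying TensorFlow
    SavedModel is an opaque on-disk artifact.
    """
    raise NotImplementedError(
        "TFModelWrapper cannot be serialized; "
        "convert the original .jax model instead."
    )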


19-20: Confirm the import of 'jnp' from 'deepmd.jax.env'

You are importing jnp from deepmd.jax.env:

from deepmd.jax.env import (
    jnp,
)

Ensure that this import provides the expected jnp (usually jax.numpy). If there are no additional configurations or wrappers in deepmd.jax.env, you might consider importing directly:

import jax.numpy as jnp
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between d46d5f0 and 5a7fc4a.

📒 Files selected for processing (8)
  • deepmd/backend/jax.py (1 hunks)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
  • deepmd/jax/jax2tf/__init__.py (1 hunks)
  • deepmd/jax/jax2tf/serialization.py (1 hunks)
  • deepmd/jax/jax2tf/tfmodel.py (1 hunks)
  • deepmd/jax/utils/serialization.py (2 hunks)
  • doc/backend.md (1 hunks)
  • source/tests/consistent/io/test_io.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/backend/jax.py
  • deepmd/jax/utils/serialization.py
🔇 Additional comments (6)
doc/backend.md (1)

28-28: LGTM: Clear documentation of supported file extensions.

The addition of .savedmodel as a supported model filename extension is well-documented and aligns with the PR's objective of enabling JAX2TF export functionality.

source/tests/consistent/io/test_io.py (2)

155-158: LGTM: Model saving updated to use correct suffix index.

The changes correctly utilize the suffix index when saving model data, which aligns with the PR's objective of supporting SavedModel format via jax2tf.

Let's verify the JAX backend suffixes:

#!/bin/bash
# Search for JAXBackend suffixes definition
ast-grep --pattern 'class JAXBackend {
  $$$
  suffixes = $_
  $$$
}'

143-150: LGTM: Backend testing loop refactored for better suffix handling.

The refactoring improves clarity by explicitly handling suffix indices. The comment about jax2tf incompatibility with tf v1 behaviors provides valuable context.

Let's verify the jax2tf compatibility issue:

✅ Verification successful

Based on the search results, I can now verify the jax2tf compatibility issue with TF v1 behaviors. The evidence shows:

  1. In deepmd/jax/jax2tf/__init__.py, there's an explicit error message:
"Unfortunatly, jax2tf (requires eager execution) cannot be used with the TensorFlow backend (disables eager execution)."
  2. The comment in the test file accurately reflects this technical limitation, as jax2tf requires eager execution mode while TF v1 behaviors disable it.

The comment about jax2tf incompatibility with TF v1 behaviors is accurate and well-documented.

The test case exclusion is justified by the architectural incompatibility between jax2tf's eager execution requirement and TF v1's graph mode execution.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any related jax2tf compatibility checks or version constraints
rg -A 5 "jax2tf.*tf.*v1|tf.*v1.*jax2tf"

Length of output: 165


Script:

#!/bin/bash
# Search for jax2tf related code and comments
rg -g '!*.pyc' -g '!*.so' "jax2tf" -A 5

# Search for JAX backend suffix related code
ast-grep --pattern 'class JAXBackend {
  $$$
  suffixes = $_
  $$$
}'

# Look for TF version compatibility checks
rg -g '!*.pyc' -g '!*.so' "tensorflow.*version|tf.*version" -A 3

Length of output: 55695

deepmd/jax/infer/deep_eval.py (1)

93-102: LGTM: HLO model loading implementation is well-structured.

The implementation correctly handles the loading and initialization of HLO models, with proper extraction of model components and binary data conversion.

deepmd/jax/jax2tf/tfmodel.py (2)

145-146: Verify that 'self.call_lower' is appropriately passed as a callable

In the call method, self.call_lower is passed to model_call_from_call_lower:

call_lower=self.call_lower,

Ensure that self.call_lower is intended to be a callable function and that it conforms to the expected interface required by model_call_from_call_lower.


178-187: Handle potential issues when creating empty arrays for 'fparam' and 'aparam'

When fparam or aparam is None, empty arrays are created with shapes based on input dimensions. Ensure that self.get_dim_fparam() and self.get_dim_aparam() return valid dimensions to prevent issues with zero or negative sizes.

If additional validation is needed, consider adding checks:

dim_fparam = self.get_dim_fparam()
if dim_fparam <= 0:
    raise ValueError("Invalid dimension for frame parameters.")

dim_aparam = self.get_dim_aparam()
if dim_aparam <= 0:
    raise ValueError("Invalid dimension for atomic parameters.")

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz (Member, Author) commented Nov 3, 2024

I found that XLA (deps here) is too hard to build, so I still prefer JAX2TF for C++ deployment.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (5)
deepmd/jax/jax2tf/tfmodel.py (5)

178-178: Remove unnecessary error message comment

The comment on line 178 appears to be an error message from a previous run and may cause confusion. Consider removing it or replacing it with a more informative comment if needed.


271-273: Add docstring to get_sel method

The method get_sel lacks a docstring. Adding a docstring will improve readability and help others understand its purpose.
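For example (a sketch, assuming `get_sel` returns the per-type selection counts stored on the wrapper):

def get_sel(self) -> list:
    """Return the number of selected neighboring atoms for each atom type."""
    return self.sel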


278-280: Add docstring to mixed_types method

The method mixed_types does not have a docstring. Including a docstring will enhance code clarity and maintainability.


254-255: Correct the return types in docstrings

The return types in the docstrings of the deserialize (line 254) and get_model (line 322) methods mention BaseModel and BaseBaseModel, respectively. Since these methods should return an instance of TFModelWrapper, consider updating the docstrings to reflect the correct return type.

Also applies to: 322-323


267-270: Consolidate get_nnei and get_nsel methods

Both get_nnei and get_nsel methods return the same value and have identical docstrings. To reduce redundancy, consider consolidating them into a single method or clarifying the distinction if they are intended to serve different purposes.

Also applies to: 274-277

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 5a7fc4a and 980b4a9.

📒 Files selected for processing (1)
  • deepmd/jax/jax2tf/tfmodel.py (1 hunks)

njzjz marked this pull request as draft November 3, 2024 21:46
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (4)
deepmd/jax/jax2tf/tfmodel.py (4)

49-52: Add type annotations and docstring to the __init__ method for clarity

Currently, the __init__ method lacks type annotations and a docstring. Adding these will improve code readability and help others understand how to use the class.

Update the method to include type annotations and a docstring:

 class TFModelWrapper(tf.Module):
     def __init__(
-        self,
-        model,
+        self,
+        model: str,
     ) -> None:
+        """
+        Initialize the TFModelWrapper.
+
+        Parameters
+        ----------
+        model : str
+            Path to the SavedModel directory.
+        """
         self.model = tf.saved_model.load(model)

75-109: Add return type annotation to the __call__ method for clarity

The __call__ method currently lacks a return type annotation. Adding a return type will improve code readability and help users understand what the method returns.

Update the method signature and import Dict from typing:

 from typing import (
     Any,
     Optional,
+    Dict,
+    List,
 )

 def __call__(
     self,
     coord: jnp.ndarray,
     atype: jnp.ndarray,
     box: Optional[jnp.ndarray] = None,
     fparam: Optional[jnp.ndarray] = None,
     aparam: Optional[jnp.ndarray] = None,
     do_atomic_virial: bool = False,
-) -> Any:
+) -> Dict[str, jnp.ndarray]:
     """
     Return model prediction.
     ...

233-242: Implement or remove methods that raise NotImplementedError

The methods serialize, deserialize, update_sel, and get_model currently raise NotImplementedError. If these methods are intended to be used in the future, consider implementing them. If not, you may want to remove them to avoid confusion and potential runtime errors.

Also applies to: 244-257, 281-307


48-325: Consider adding unit tests for TFModelWrapper

Since TFModelWrapper is a new class with significant functionality, adding unit tests would help ensure its correctness and maintainability.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 980b4a9 and 8e216f5.

📒 Files selected for processing (2)
  • deepmd/jax/jax2tf/serialization.py (1 hunks)
  • deepmd/jax/jax2tf/tfmodel.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/jax/jax2tf/serialization.py
🔇 Additional comments (1)
deepmd/jax/jax2tf/tfmodel.py (1)

164-196: Handle potential None values for mapping in call_lower method

In the call_lower method, mapping is passed to the TensorFlow model. Ensure that the TensorFlow model can handle None values for mapping, or provide a default value if necessary to prevent runtime errors.
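One possible guard (a sketch; assumes `extended_atype` carries the frame and atom dimensions):

if mapping is None:
    # Fall back to an identity mapping so the TF graph never receives None.
    nframes, nall = extended_atype.shape[:2]
    mapping = jnp.tile(jnp.arange(nall, dtype=jnp.int64)[None, :], (nframes, 1))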


Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz marked this pull request as ready for review November 4, 2024 00:14
wanghan-iapcm (Collaborator) left a comment

Before we support savedmodel in DeepEval, shall we write a specific test for it under TF2?

njzjz (Member, Author) commented Nov 4, 2024

> Before we support savedmodel in DeepEval, shall we write a specific test for it under TF2?

Savedmodel support in DeepEval has been added in this PR. The test is added as well, but due to the incompatibility between TF1 and TF2, the TF1 and TF2 tests cannot run in the same session.
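A sketch of such a TF2-only skip, using the `DP_TEST_TF2_ONLY` flag introduced in `source/tests/utils.py` (the class name and import path are illustrative):

import unittest

from ..utils import DP_TEST_TF2_ONLY

@unittest.skipUnless(
    DP_TEST_TF2_ONLY, "jax2tf requires TF2 eager mode; set DP_TEST_TF2_ONLY=1"
)
class TestSavedModelDeepEval(unittest.TestCase):
    ...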

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
source/tests/utils.py (1)

11-11: LGTM! Consider adding a descriptive comment.

The implementation is clean and follows the existing pattern for environment variable checks. However, it would be helpful to add a comment explaining the purpose of this flag and when it should be set.

+# Skip TF1-incompatible tests when running with TensorFlow 2 in eager mode
 DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1"
.github/workflows/test_python.yml (1)

56-61: LGTM! Consider adding documentation for test isolation.

The addition of a separate job for TF2 eager mode testing is a good approach to handle the TF1/TF2 compatibility issues. The test isolation using DP_TEST_TF2_ONLY flag is appropriate.

Consider adding a comment explaining:

  • Why this specific test needs isolation
  • The reason for running only when matrix.group == 1
+    # Separate job for TF2 eager mode tests due to TF1/TF2 compatibility issues
+    # Only runs in group 1 to avoid redundant execution across test groups
     - name: Test TF2 eager mode
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 3e919a1 and 60694a8.

📒 Files selected for processing (3)
  • .github/workflows/test_python.yml (1 hunks)
  • source/tests/consistent/io/test_io.py (3 hunks)
  • source/tests/utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/consistent/io/test_io.py
🔇 Additional comments (1)
.github/workflows/test_python.yml (1)

56-61: Verify test coverage for SavedModel functionality.

Let's ensure that the isolated test covers the new SavedModel export functionality.

njzjz added 4 commits November 3, 2024 22:52, each signed off by Jinzhe Zeng <jinzhe.zeng@rutgers.edu>.
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
.github/workflows/test_python.yml (1)

60-65: Consider test execution strategy improvements.

While the TF2 eager mode test is essential for validating SavedModel functionality, running it only in group 1 might delay the detection of issues. Additionally:

  1. Consider adding a comment explaining why this test needs isolation
  2. The test path could be moved to a variable for better maintainability
+    # TF1 and TF2 tests cannot run in the same session, so we isolate TF2 tests
+    - name: Test TF2 eager mode
+      env:
+        TEST_FILE: source/tests/consistent/io/test_io.py
+        NUM_WORKERS: 0
+        DP_TEST_TF2_ONLY: 1
+      run: pytest --cov=deepmd ${TEST_FILE} --durations=0
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 60694a8 and de0caa1.

📒 Files selected for processing (1)
  • .github/workflows/test_python.yml (2 hunks)
🔇 Additional comments (2)
.github/workflows/test_python.yml (2)

41-44: LGTM! Environment configuration is well-structured.

The environment variables are correctly configured to:

  • Simplify package sources
  • Enable TensorFlow support in Horovod
  • Disable PyTorch in Horovod (reduces complexity)
  • Enable MPI support

28-33: Verify horovod installation reliability.

The removal of --only-binary=horovod flag and addition of --no-build-isolation might cause build issues in certain environments. While this change might be necessary for compatibility, it could lead to longer build times or compilation failures.

✅ Verification successful

Based on the git history, I can now provide a final response:

Horovod installation changes are intentional and safe

The changes to horovod installation flags are part of a deliberate evolution in the CI pipeline:

  • Initially used --only-binary=horovod with horovod[tensorflow-cpu]
  • Transitioned to separate tensorflow-cpu installation and simplified horovod installation
  • --no-build-isolation flag was added after testing and multiple iterations

The changes are well-tested through multiple commits and align with the project's move to use uv package installer and separate dependency management. The current approach has been stable in the CI pipeline.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify if horovod installation succeeds with these flags
uv pip install horovod --no-build-isolation 2>&1 | grep -i "error\|warning"

Length of output: 72


Script:

#!/bin/bash
# Let's check the git history to understand the context of these changes
git log -p -- .github/workflows/test_python.yml | grep -A 10 -B 10 "horovod"

# Also check if there are any related discussions or issues
git log --grep="horovod" --pretty=format:"%h %s"

# Look for any related configuration or build files
fd horovod

Length of output: 13203

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
.github/workflows/test_python.yml (1)

60-65: Consider expanding TF2 eager mode test coverage.

While the addition of TF2 eager mode testing is valuable, consider:

  1. The test only covers test_io.py - should we include other JAX2TF-related tests?
  2. Running only in group 1 is efficient but makes the test results less visible in the CI output.

Consider adding a comment explaining why this test is isolated to group 1 and whether there are plans to expand coverage.

pyproject.toml (1)

447-447: Consider separating MPI configuration changes.

The addition of the OpenMPI source appears to be independent of the main PR objective (JAX backend and SavedModel export). Consider moving this change to a separate PR to maintain clear change boundaries and simplify review processes.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 306eef3 and 9a94118.

📒 Files selected for processing (2)
  • .github/workflows/test_python.yml (2 hunks)
  • pyproject.toml (1 hunks)
🔇 Additional comments (2)
.github/workflows/test_python.yml (1)

41-44: LGTM! Environment variables properly configured for TensorFlow support.

The explicit Horovod configuration flags ensure proper build settings for TensorFlow support while disabling unnecessary PyTorch integration. This aligns well with the JAX2TF focus of the PR.

pyproject.toml (1)

447-447: LGTM! Consider documenting MPI implementation choice.

The addition of openmpi as an alternative source for mpi4py alongside the existing mpich source provides users with flexibility in choosing their MPI implementation. This is a good practice for supporting different environments and user preferences.

Let's verify the MPI configuration in the codebase:

Consider adding documentation to clarify:

  1. The differences between mpich and openmpi options
  2. Guidelines for choosing between the implementations
  3. Any known compatibility considerations
✅ Verification successful

Based on the search results, I can now generate the final response since I have sufficient information to verify the MPI configuration concern:

Both OpenMPI and MPICH implementations are supported and can coexist safely

The codebase is designed to work with different MPI implementations, as evidenced by:

  • The build system properly detects and configures MPI settings at compile time
  • Documentation explicitly mentions support for both MPICH and OpenMPI
  • MPI-specific configurations are properly isolated and handled through CMake's find_package(MPI)

The addition of OpenMPI as an alternative source for mpi4py is safe and provides users with flexibility in choosing their preferred MPI implementation. The existing MPICH configuration remains unaffected.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any MPI-specific configuration or requirements

# Test: Search for MPI-related configuration and documentation
rg -i "mpi" --type md
rg "MPI" -A 5 --type cmake

Length of output: 32054

njzjz added this pull request to the merge queue Nov 4, 2024
Merged via the queue into deepmodeling:devel with commit 38815b3 Nov 4, 2024
60 checks passed
njzjz deleted the savedmodel branch November 4, 2024 22:15