# feat(jax): export call_lower to SavedModel via jax2tf #4254
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
## Walkthrough
The pull request introduces several updates across multiple files, primarily enhancing the JAX backend's capabilities. Key changes include the addition of the ".savedmodel" suffix to the `JAXBackend` class, updates to the `deserialize_to_file` function to support TensorFlow's SavedModel format, and modifications to the `DeepEval` class for improved model loading based on file extensions. New functionalities are added in the `jax2tf` module, including a TensorFlow model wrapper class. Documentation is also updated to reflect these changes, particularly regarding supported file formats.
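For orientation, a minimal sketch of the conversion flow this enables is shown below. The hook names follow `deepmd/jax/utils/serialization.py` as described in this PR, but treat the exact signatures as assumptions; the same conversion is normally driven from the command line.

```python
# Hedged sketch: convert a native .jax model into a TF SavedModel directory.
from deepmd.jax.utils.serialization import (
    deserialize_to_file,
    serialize_from_file,
)

data = serialize_from_file("model.jax")        # read the native JAX model
deserialize_to_file("model.savedmodel", data)  # write a SavedModel via jax2tf
```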
## Changes
| File Path | Change Summary |
|----------------------------------------------|---------------------------------------------------------------------------------------------------|
| `deepmd/backend/jax.py` | Updated `suffixes` in `JAXBackend` class to include ".savedmodel". |
| `deepmd/jax/utils/serialization.py` | Enhanced `deserialize_to_file` to handle ".savedmodel" files; generalized error handling. |
| `deepmd/jax/infer/deep_eval.py` | Added conditional logic to load models based on file extension; improved error handling. |
| `deepmd/jax/jax2tf/__init__.py` | Introduced a check for TensorFlow eager execution; raises `RuntimeError` if not in eager mode. |
| `deepmd/jax/jax2tf/serialization.py` | Added `deserialize_to_file` function for TensorFlow model files. |
| `deepmd/jax/jax2tf/tfmodel.py` | Introduced `TFModelWrapper` class for JAX and TensorFlow integration; added multiple methods. |
| `doc/backend.md` | Updated documentation to include ".savedmodel" as a supported file extension for JAX backend. |
| `source/tests/consistent/io/test_io.py` | Modified `test_deep_eval` to improve backend handling and clarity; no changes to functionality. |
| `.github/workflows/test_python.yml` | Added new job "Test TF2 eager mode" for targeted testing; refined caching strategy. |
| `source/tests/utils.py` | Introduced new variable `DP_TEST_TF2_ONLY` for environment variable checks. |
## Possibly related PRs
- **#3776**: The changes in the `.github/workflows/test_python.yml` file involve adjustments to the Python environment setup and package installations, which may relate to the overall integration of the JAX backend.
- **#4156**: This PR introduces a JAX backend implementation, which is directly related to the changes made in the main PR regarding the `JAXBackend` class.
- **#4236**: The modifications to the `JAXBackend` class and the introduction of serialization and deserialization functionalities are closely related to the changes in the main PR.
- **#4251**: The introduction of new methods in the `make_model` function that involve atomic calculations aligns with the changes made in the main PR regarding the handling of model outputs.
- **#4259**: The documentation updates for the JAX backend are relevant as they provide context and support for the changes made in the main PR.
- **#4278**: The enhancements to the `DipoleFitting` and `PolarFitting` classes to support array API compatibility are related to the main PR's focus on JAX integration.
- **#4294**: The modifications to the `DescrptDPA2` class to improve compatibility with array operations are directly relevant to the changes made in the main PR.
## Suggested labels
`breaking change`
## Suggested reviewers
- wanghan-iapcm
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
deepmd/backend/jax.py (1)
Line range hint 27-42: Document supported formats in class docstring.

Consider enhancing the class documentation to explicitly mention the supported file formats. This would help users understand which formats they can use with the JAX backend.

```diff
 class JAXBackend(Backend):
-    """JAX backend."""
+    """JAX backend.
+
+    Supports the following model formats:
+    - .jax: Native JAX format
+    - .savedmodel: TensorFlow SavedModel format
+    """
```

deepmd/jax/utils/serialization.py (2)
58-64: Clarify the usage of `polymorphic_shapes` in `jax2tf.convert`.

The `polymorphic_shapes` parameter uses variables like `nf`, `nloc`, and `nghost`. If these are placeholders for dynamic dimensions, consider providing documentation or comments explaining their meanings. This will enhance code readability and maintainability for future developers.
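For readers unfamiliar with jax2tf shape polymorphism, here is a small self-contained sketch of what such shape strings mean. The toy function and sizes are illustrative only and stand in for the PR's actual `call_lower`:

```python
# "nf", "nloc", and "nghost" are symbolic dimension names, so a single
# converted function serves any batch size and atom count.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def toy_fn(coord):  # stand-in for the real call_lower
    return jnp.sum(coord, axis=(1, 2))

tf_fn = jax2tf.convert(toy_fn, polymorphic_shapes=["(nf, nloc + nghost, 3)"])
print(tf_fn(tf.zeros([2, 7, 3])))  # any nf and nloc + nghost work
```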
Line range hint 88-89: Update `serialize_from_file` to handle `.savedmodel` files.

Currently, `serialize_from_file` only supports `.jax` files and raises a `ValueError` for other formats. Since `deserialize_to_file` now handles `.savedmodel` files, consider adding support for serializing from `.savedmodel` files for consistency. Alternatively, update the error message to reflect the supported formats.

Example of updating the error message:

```diff
 else:
-    raise ValueError("JAX backend only supports converting .jax directory")
+    raise ValueError("JAX backend only supports converting .jax and .savedmodel formats")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- deepmd/backend/jax.py (1 hunks)
- deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
- deepmd/jax/utils/serialization.py (1 hunks)
🔇 Additional comments (6)
deepmd/backend/jax.py (1)
41-41: Verify SavedModel serialization implementation.

The addition of the ".savedmodel" suffix aligns with the PR objective. Let's verify that the serialization hooks properly support this format.
✅ Verification successful
SavedModel serialization implementation verified
The JAX backend properly implements SavedModel format support:
- The backend correctly registers the `.savedmodel` suffix in the `JAXBackend` class
- The serialization implementation in `deepmd/jax/utils/serialization.py` includes:
  - Converting the JAX model to TensorFlow using `jax2tf.convert`
  - Saving in the TensorFlow SavedModel format using `tf.saved_model.save`
  - Proper shape specifications for the converted model
  - Support for custom gradients
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Verify that serialization hooks handle SavedModel format
# Test: Check if deserialize_to_file has SavedModel handling
rg -A 10 "deserialize_to_file.*savedmodel" deepmd/jax/utils/serialization.py
# Test: Check if serialize_from_file has SavedModel handling
rg -A 10 "serialize_from_file.*savedmodel" deepmd/jax/utils/serialization.py
```

Length of output: 157
Script:
```bash
#!/bin/bash
# Let's check for any SavedModel related code in serialization files
rg -l "SavedModel|savedmodel" --type py

# Check implementation details in serialization-related files
rg -A 10 "serialize|deserialize" --type py

# Look for any JAX to SavedModel conversion functions
ast-grep --pattern 'def $_($$) {
  $$$
  savedmodel
  $$$
}'
```

Length of output: 456711
Script:
```bash
#!/bin/bash
# Let's check for SavedModel related code in JAX backend
rg -A 10 "SavedModel|savedmodel" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Check if there are any JAX to TF conversion functions
rg -A 10 "jax2tf" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Look for any model export/save functions
rg -A 10 "save.*model|export.*model" deepmd/jax/utils/serialization.py deepmd/backend/jax.py
```

Length of output: 5000
deepmd/dpmodel/descriptor/se_e2_a.py (2)
558-558: LGTM: Simplified array assignment.

The removal of the unnecessary `xp.asarray` conversion is appropriate since `self.sel_cumsum` is already suitable for indexing operations.
Line range hint 580-599: Performance optimization: Replaced einsum with broadcasting.

The replacement of `einsum` operations with equivalent `sum` and broadcasting operations is a good optimization. The changes maintain mathematical equivalence while potentially offering better performance on some hardware accelerators.

Original:

```python
gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
grrg = xp.einsum("flid,fljd->flij", gr, gr1)
```

New:

```python
gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
```

Let's verify the numerical equivalence of these operations:
✅ Verification successful
Optimization change is mathematically equivalent and safe to use
The replacement of einsum with broadcasting and sum operations is mathematically sound for the following reasons:
- The original einsum operation `xp.einsum("lni,lnj->lij", gg, tr)` is equivalent to the new broadcasting operation `xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)`, as they both perform the same matrix multiplication along the specified axes.
- Similarly, `xp.einsum("flid,fljd->flij", gr, gr1)` is equivalent to `xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)`, maintaining the same contraction pattern.
- The codebase shows consistent use of both einsum and broadcasting+sum patterns across different implementations (TensorFlow, PyTorch, and NumPy), indicating this is a well-understood transformation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Search for any tests that verify the descriptor's output
rg -l "test.*descriptor.*se.*a" --type py
```

Length of output: 85
Script:
```bash
#!/bin/bash
# Let's examine the test file content to see if it verifies numerical equivalence
cat source/tests/tf/test_descrpt_se_a_mask.py
```

Length of output: 8099
Script:
```bash
#!/bin/bash
# Let's look for any other tests that might verify the descriptor output
rg -l "test.*descriptor.*se.*[aA]" --type py

# Also search for any numerical tests specifically related to einsum operations
rg "einsum|sum.*broadcast" --type py
```

Length of output: 3078
deepmd/jax/utils/serialization.py (3)
50-51: Verify that `call_lower` is a valid attribute of `model`.

The code assigns `call_lower = model.call_lower`. Please ensure that all models being deserialized have the `call_lower` attribute. If some models might not have this attribute, consider adding a check or handling potential `AttributeError` exceptions to prevent runtime errors.
68-73: Ensure correct tensor shapes in `tf.TensorSpec`.

The `tf.TensorSpec` definitions include dynamic dimensions (`None`) and a call to `model.get_nnei()`. Verify that `model.get_nnei()` returns an integer and that the tensor shapes align with the expected input dimensions. Misalignment can lead to runtime errors when the SavedModel is used.
76-79: Review the necessity of `experimental_custom_gradients=True`.

The option `experimental_custom_gradients=True` is used in `tf.saved_model.SaveOptions`. Confirm that custom gradients are required for your use case. If not, removing this option could simplify the code and avoid potential compatibility issues with future TensorFlow versions.
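The three comments above all concern the same export path. As a minimal, runnable sketch of that pattern, with a toy function standing in for the jax2tf-converted `call_lower` and illustrative shapes and dtypes:

```python
# Save a tf.function with a fixed input signature and custom-gradient support.
import tensorflow as tf

module = tf.Module()
module.call_lower = tf.function(
    lambda coord: tf.reduce_sum(coord, axis=[1, 2]),  # stand-in function
    input_signature=[tf.TensorSpec([None, None, 3], tf.float64)],
)
tf.saved_model.save(
    module,
    "model.savedmodel",
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
restored = tf.saved_model.load("model.savedmodel")  # round-trip check
```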
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##            devel    #4254      +/-   ##
==========================================
- Coverage   84.38%   84.33%   -0.06%
==========================================
  Files         563      570       +7
  Lines       52810    53070     +260
  Branches     3054     3054
==========================================
+ Hits        44564    44756     +192
- Misses       7287     7358      +71
+ Partials      959      956       -3
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (15)
deepmd/jax/jax2tf/__init__.py (2)
4-11: Fix typo and enhance error message clarity.

The runtime check for eager execution is crucial and well-placed. However, the error message could be improved.
```diff
 if not tf.executing_eagerly():
     # TF disallow temporary eager execution
     raise RuntimeError(
-        "Unfortunatly, jax2tf (requires eager execution) cannot be used with the "
+        "Unfortunately, jax2tf cannot be used with the TensorFlow backend. This is because "
+        "jax2tf requires TensorFlow to run in eager execution mode, which is disabled by "
         "TensorFlow backend (disables eager execution). "
         "If you are converting a model between different backends, "
-        "considering converting to the `.dp` format first."
+        "consider converting to the `.dp` format first."
     )
```
5-5: Improve comment clarity.

The current comment is ambiguous and could be more descriptive.
```diff
-# TF disallow temporary eager execution
+# The TensorFlow backend explicitly disables eager execution mode
```

doc/backend.md (1)
33-33: Consider enhancing the SavedModel format documentation.

While the current documentation explains the basic requirement, it would be beneficial to add:
- A brief example or link to example usage
- Any limitations or considerations when using the SavedModel format
Example addition:
```diff
 `.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
+For usage examples and best practices, refer to the [JAX2TF conversion guide](https://www.tensorflow.org/guide/jax2tf).
```
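A short usage sketch could accompany this as well. The following assumes deepmd-kit's standard `DeepPot` inference API with placeholder data, so treat it as illustrative rather than the documented example:

```python
# Hypothetical doc example: evaluate a SavedModel through the inference API.
import numpy as np
from deepmd.infer import DeepPot

dp = DeepPot("model.savedmodel")  # the extension selects the SavedModel loader
coord = np.array([[1, 0, 0], [0, 0, 1.5], [1, 0, 3]]).reshape([1, -1])
cell = np.diag(10 * np.ones(3)).reshape([1, -1])
atype = [1, 0, 1]
e, f, v = dp.eval(coord, cell, atype)  # energy, forces, virial
```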
deepmd/jax/infer/deep_eval.py (3)

103-108: Consider adding SavedModel validation and error handling.

While the implementation is functional, it could benefit from additional error handling to gracefully handle invalid SavedModel files or import failures.
Consider adding try-catch blocks:
```diff
 elif model_file.endswith(".savedmodel"):
     from deepmd.jax.jax2tf.tfmodel import (
         TFModelWrapper,
     )
-
-    self.dp = TFModelWrapper(model_file)
+    try:
+        self.dp = TFModelWrapper(model_file)
+    except Exception as e:
+        raise ValueError(f"Failed to load SavedModel from {model_file}: {str(e)}")
```
109-110: Enhance error message with supported file extensions.

The current error message could be more helpful by explicitly listing the supported file extensions.
```diff
-raise ValueError("Unsupported file extension")
+raise ValueError(
+    f"Unsupported file extension for {model_file}. "
+    "Supported extensions are: .hlo, .savedmodel"
+)
```
93-110: Consider implementing a model loader factory pattern.

The current implementation with multiple conditional blocks could be refactored to use a factory pattern, making it more maintainable and extensible for future model formats.
This would involve:
- Creating a ModelLoader interface/abstract class
- Implementing specific loaders for each format (HLOModelLoader, SavedModelLoader)
- Using a factory to create the appropriate loader based on file extension
Example structure:
```python
from abc import ABC, abstractmethod


class ModelLoader(ABC):
    @abstractmethod
    def load(self, model_file: str):
        pass


class HLOModelLoader(ModelLoader):
    def load(self, model_file: str):
        model_data = load_dp_model(model_file)
        return HLO(
            stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
            stablehlo_atomic_virial=model_data["@variables"][
                "stablehlo_atomic_virial"
            ].tobytes(),
            model_def_script=model_data["model_def_script"],
            **model_data["constants"],
        )


class SavedModelLoader(ModelLoader):
    def load(self, model_file: str):
        return TFModelWrapper(model_file)


class ModelLoaderFactory:
    @staticmethod
    def create_loader(model_file: str) -> ModelLoader:
        if model_file.endswith(".hlo"):
            return HLOModelLoader()
        elif model_file.endswith(".savedmodel"):
            return SavedModelLoader()
        raise ValueError(f"Unsupported file extension for {model_file}")
```

This would make the code more maintainable and easier to extend with new model formats in the future.
deepmd/jax/jax2tf/serialization.py (3)
75-88: Use explicit data types for clarity and consistency.

In the creation of `tf.Variable` instances, using explicit data types like `tf.float64` instead of `tf.double` enhances readability and follows TensorFlow conventions.

Update the data types as follows:

```python
tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.float64)
...
tf_model.min_nbor_dist = tf.Variable(
    model.get_min_nbor_dist(), dtype=tf.float64
)
```
90-92: Optimize storage of JSON data in `tf.Variable`.

Storing large JSON strings in a `tf.Variable` may not be efficient, especially if the data does not change during model execution. Consider using a `tf.constant` or attaching the JSON data as an attribute instead.

You could modify the code as follows:

```diff
-tf_model.model_def_script = tf.Variable(
+tf_model.model_def_script = tf.constant(
     json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
 )
```
93-97: Handle exceptions during model saving.

When saving the model using `tf.saved_model.save`, exceptions can occur due to issues like file permission errors or invalid model definitions. Wrapping the save operation in a try-except block can provide clearer error messages.

Implement exception handling:

```diff
 try:
     tf.saved_model.save(
         tf_model,
         model_file,
         options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
     )
+except Exception as e:
+    raise RuntimeError(f"Failed to save the model to '{model_file}': {e}")
```

deepmd/jax/jax2tf/tfmodel.py (6)
178-178: Remove or clarify the commented-out error message.

The comment on line 178 appears to be an error message from a previous run:

```python
# Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
```

This could cause confusion for readers. If it's not needed, consider removing it. If it serves as a reminder or note, provide additional context.
274-277: Ensure consistency between `get_nsel` and `get_nnei` methods.

Both the `get_nsel` and `get_nnei` methods seem to serve the same purpose: returning the number of selected neighboring atoms. Additionally, their docstrings are nearly identical. Consider consolidating these methods or clearly distinguishing their functionalities to avoid confusion.
75-109: Avoid duplicating docstrings between `__call__` and `call` methods.

The `__call__` and `call` methods have identical docstrings. This duplication can lead to maintenance issues and inconsistencies in the future. Consider having a single comprehensive docstring in one method and referencing it in the other.

For example, you can modify the `__call__` method's docstring:

```diff
 def __call__(
     self,
     coord: jnp.ndarray,
     atype: jnp.ndarray,
     box: Optional[jnp.ndarray] = None,
     fparam: Optional[jnp.ndarray] = None,
     aparam: Optional[jnp.ndarray] = None,
     do_atomic_virial: bool = False,
 ) -> Any:
-    """Return model prediction.
-
-    Parameters
-    ----------
-    ... [docstring contents] ...
-    """
+    """Return model prediction by calling the `call` method."""
     return self.call(coord, atype, box, fparam, aparam, do_atomic_virial)
```
197-204: Add type hints to getter methods for clarity.

The methods `get_rcut`, `get_dim_fparam`, and `get_dim_aparam` currently lack return type hints. Adding explicit return types enhances code readability and helps with static type checking.

Apply return type hints as follows:

```python
def get_rcut(self) -> float:
    """Get the cut-off radius."""
    return self.rcut

def get_dim_fparam(self) -> int:
    """Get the number (dimension) of frame parameters of this atomic model."""
    return self.dim_fparam

def get_dim_aparam(self) -> int:
    """Get the number (dimension) of atomic parameters of this atomic model."""
    return self.dim_aparam
```
243-258: Implement the `serialize` and `deserialize` methods or raise `NotImplementedError` with guidance.

The `serialize` and `deserialize` methods currently raise `NotImplementedError` without additional context. If these methods are intended for future implementation, consider adding a comment or docstring explaining when they will be implemented or what is expected.
19-20: Confirm the import of `jnp` from `deepmd.jax.env`.

You are importing `jnp` from `deepmd.jax.env`:

```python
from deepmd.jax.env import (
    jnp,
)
```

Ensure that this import provides the expected `jnp` (usually `jax.numpy`). If there are no additional configurations or wrappers in `deepmd.jax.env`, you might consider importing directly:

```python
import jax.numpy as jnp
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (8)
- deepmd/backend/jax.py (1 hunks)
- deepmd/jax/infer/deep_eval.py (1 hunks)
- deepmd/jax/jax2tf/__init__.py (1 hunks)
- deepmd/jax/jax2tf/serialization.py (1 hunks)
- deepmd/jax/jax2tf/tfmodel.py (1 hunks)
- deepmd/jax/utils/serialization.py (2 hunks)
- doc/backend.md (1 hunks)
- source/tests/consistent/io/test_io.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/backend/jax.py
- deepmd/jax/utils/serialization.py
🔇 Additional comments (6)
doc/backend.md (1)
28-28: LGTM: Clear documentation of supported file extensions.
The addition of .savedmodel as a supported model filename extension is well-documented and aligns with the PR's objective of enabling JAX2TF export functionality.
source/tests/consistent/io/test_io.py (2)
155-158: LGTM: Model saving updated to use correct suffix index.
The changes correctly utilize the suffix index when saving model data, which aligns with the PR's objective of supporting SavedModel format via jax2tf.
Let's verify the JAX backend suffixes:
```bash
#!/bin/bash
# Search for JAXBackend suffixes definition
ast-grep --pattern 'class JAXBackend {
  $$$
  suffixes = $_
  $$$
}'
```

143-150: LGTM: Backend testing loop refactored for better suffix handling.
The refactoring improves clarity by explicitly handling suffix indices. The comment about jax2tf incompatibility with tf v1 behaviors provides valuable context.
Let's verify the jax2tf compatibility issue:
✅ Verification successful
Based on the search results, I can now verify the jax2tf compatibility issue with TF v1 behaviors. The evidence shows:
- In `deepmd/jax/jax2tf/__init__.py`, there's an explicit error message: "Unfortunatly, jax2tf (requires eager execution) cannot be used with the TensorFlow backend (disables eager execution)."
- The comment in the test file accurately reflects this technical limitation, as jax2tf requires eager execution mode while TF v1 behaviors disable it.
The comment about jax2tf incompatibility with TF v1 behaviors is accurate and well-documented.
The test case exclusion is justified by the architectural incompatibility between jax2tf's eager execution requirement and TF v1's graph mode execution.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Search for any related jax2tf compatibility checks or version constraints
rg -A 5 "jax2tf.*tf.*v1|tf.*v1.*jax2tf"
```
Length of output: 165
Script:
```bash
#!/bin/bash
# Search for jax2tf related code and comments
rg -g '!*.pyc' -g '!*.so' "jax2tf" -A 5

# Search for JAX backend suffix related code
ast-grep --pattern 'class JAXBackend {
  $$$
  suffixes = $_
  $$$
}'

# Look for TF version compatibility checks
rg -g '!*.pyc' -g '!*.so' "tensorflow.*version|tf.*version" -A 3
```
Length of output: 55695
deepmd/jax/infer/deep_eval.py (1)
93-102: LGTM: HLO model loading implementation is well-structured.
The implementation correctly handles the loading and initialization of HLO models, with proper extraction of model components and binary data conversion.
deepmd/jax/jax2tf/tfmodel.py (2)
145-146: Verify that `self.call_lower` is appropriately passed as a callable.

In the `call` method, `self.call_lower` is passed to `model_call_from_call_lower`:

```python
call_lower=self.call_lower,
```

Ensure that `self.call_lower` is intended to be a callable function and that it conforms to the expected interface required by `model_call_from_call_lower`.
178-187: Handle potential issues when creating empty arrays for `fparam` and `aparam`.

When `fparam` or `aparam` is `None`, empty arrays are created with shapes based on input dimensions. Ensure that `self.get_dim_fparam()` and `self.get_dim_aparam()` return valid dimensions to prevent issues with zero or negative sizes.

If additional validation is needed, consider adding checks:
```python
dim_fparam = self.get_dim_fparam()
if dim_fparam <= 0:
    raise ValueError("Invalid dimension for frame parameters.")
dim_aparam = self.get_dim_aparam()
if dim_aparam <= 0:
    raise ValueError("Invalid dimension for atomic parameters.")
```

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
I found XLA (deps here) is too hard to build, so I still prefer JAX2TF for C++ deployment.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (5)
deepmd/jax/jax2tf/tfmodel.py (5)
178-178: Remove unnecessary error message comment.

The comment on line 178 appears to be an error message from a previous run and may cause confusion. Consider removing it or replacing it with a more informative comment if needed.
271-273: Add docstring to `get_sel` method.

The method `get_sel` lacks a docstring. Adding a docstring will improve readability and help others understand its purpose.
278-280: Add docstring to `mixed_types` method.

The method `mixed_types` does not have a docstring. Including a docstring will enhance code clarity and maintainability.
254-255: Correct the return types in docstrings.

The return types in the docstrings of the `deserialize` (line 254) and `get_model` (line 322) methods mention `BaseModel` and `BaseBaseModel`, respectively. Since these methods should return an instance of `TFModelWrapper`, consider updating the docstrings to reflect the correct return type.

Also applies to: 322-323
267-270: Consolidate `get_nnei` and `get_nsel` methods.

Both the `get_nnei` and `get_nsel` methods return the same value and have identical docstrings. To reduce redundancy, consider consolidating them into a single method or clarifying the distinction if they are intended to serve different purposes.

Also applies to: 274-277
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (4)
deepmd/jax/jax2tf/tfmodel.py (4)
49-52: Add type annotations and docstring to the `__init__` method for clarity.

Currently, the `__init__` method lacks type annotations and a docstring. Adding these will improve code readability and help others understand how to use the class.

Update the method to include type annotations and a docstring:

```diff
 class TFModelWrapper(tf.Module):
     def __init__(
         self,
-        model,
+        model: str,
     ) -> None:
+        """Initialize the TFModelWrapper.
+
+        Parameters
+        ----------
+        model : str
+            Path to the SavedModel directory.
+        """
         self.model = tf.saved_model.load(model)
```
75-109: Add return type annotation to the `__call__` method for clarity.

The `__call__` method currently lacks a return type annotation. Adding a return type will improve code readability and help users understand what the method returns.

Update the method signature and import `Dict` from `typing`:

```diff
 from typing import (
     Any,
     Optional,
+    Dict,
+    List,
 )

 def __call__(
     self,
     coord: jnp.ndarray,
     atype: jnp.ndarray,
     box: Optional[jnp.ndarray] = None,
     fparam: Optional[jnp.ndarray] = None,
     aparam: Optional[jnp.ndarray] = None,
     do_atomic_virial: bool = False,
-) -> Any:
+) -> Dict[str, jnp.ndarray]:
     """Return model prediction."""
     ...
```
233-242: Implement or remove methods that raise `NotImplementedError`.

The methods `serialize`, `deserialize`, `update_sel`, and `get_model` currently raise `NotImplementedError`. If these methods are intended to be used in the future, consider implementing them. If not, you may want to remove them to avoid confusion and potential runtime errors.

Also applies to: 244-257, 281-307
48-325: Consider adding unit tests for `TFModelWrapper`.

Since `TFModelWrapper` is a new class with significant functionality, adding unit tests would help ensure its correctness and maintainability.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- deepmd/jax/jax2tf/serialization.py (1 hunks)
- deepmd/jax/jax2tf/tfmodel.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/jax/jax2tf/serialization.py
🔇 Additional comments (1)
deepmd/jax/jax2tf/tfmodel.py (1)
164-196: Handle potential `None` values for `mapping` in the `call_lower` method.

In the `call_lower` method, `mapping` is passed to the TensorFlow model. Ensure that the TensorFlow model can handle `None` values for `mapping`, or provide a default value if necessary to prevent runtime errors.

Run the following script to check if `mapping` being `None` is handled appropriately:
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
wanghan-iapcm left a comment
Before we have supported the savedmodel in deepEval, shall we write a specific test for it under tf2?
Support for DeepEval has been added in this PR. The test is added, but due to incompatibility between TF1 and TF2, they cannot be run in the same session.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
source/tests/utils.py (1)
11-11: LGTM! Consider adding a descriptive comment.

The implementation is clean and follows the existing pattern for environment variable checks. However, it would be helpful to add a comment explaining the purpose of this flag and when it should be set.
```diff
+# Skip TF1-incompatible tests when running with TensorFlow 2 in eager mode
 DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1"
```
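One way tests might consume this flag; this is an illustrative pattern, not necessarily the repository's exact code:

```python
# Gate TF1-incompatible tests on the DP_TEST_TF2_ONLY environment flag.
import os
import unittest

DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1"

@unittest.skipUnless(DP_TEST_TF2_ONLY, "requires a TF2 eager-mode session")
class TestSavedModelIO(unittest.TestCase):
    def test_roundtrip(self):
        """Exercise .savedmodel serialization under eager execution."""
```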
.github/workflows/test_python.yml (1)

56-61: LGTM! Consider adding documentation for test isolation.

The addition of a separate job for TF2 eager mode testing is a good approach to handle the TF1/TF2 compatibility issues. The test isolation using the `DP_TEST_TF2_ONLY` flag is appropriate.

Consider adding a comment explaining:
- Why this specific test needs isolation
- The reason for running only when `matrix.group == 1`

```diff
+# Separate job for TF2 eager mode tests due to TF1/TF2 compatibility issues
+# Only runs in group 1 to avoid redundant execution across test groups
 - name: Test TF2 eager mode
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- .github/workflows/test_python.yml (1 hunks)
- source/tests/consistent/io/test_io.py (3 hunks)
- source/tests/utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/consistent/io/test_io.py
🔇 Additional comments (1)
.github/workflows/test_python.yml (1)
56-61: Verify test coverage for SavedModel functionality.
Let's ensure that the isolated test covers the new SavedModel export functionality.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
.github/workflows/test_python.yml (1)
60-65: Consider test execution strategy improvements.

While the TF2 eager mode test is essential for validating SavedModel functionality, running it only in group 1 might delay the detection of issues. Additionally:
- Consider adding a comment explaining why this test needs isolation
- The test path could be moved to a variable for better maintainability
```diff
+# TF1 and TF2 tests cannot run in the same session, so we isolate TF2 tests
+- name: Test TF2 eager mode
+  env:
+    TEST_FILE: source/tests/consistent/io/test_io.py
+    NUM_WORKERS: 0
+    DP_TEST_TF2_ONLY: 1
+  run: pytest --cov=deepmd ${TEST_FILE} --durations=0
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- .github/workflows/test_python.yml (2 hunks)
🔇 Additional comments (2)
.github/workflows/test_python.yml (2)
41-44: LGTM! Environment configuration is well-structured.
The environment variables are correctly configured to:
- Simplify package sources
- Enable TensorFlow support in Horovod
- Disable PyTorch in Horovod (reduces complexity)
- Enable MPI support
28-33: Verify horovod installation reliability.
The removal of the `--only-binary=horovod` flag and the addition of `--no-build-isolation` might cause build issues in certain environments. While this change might be necessary for compatibility, it could lead to longer build times or compilation failures.
✅ Verification successful
Based on the git history, I can now provide a final response:
Horovod installation changes are intentional and safe
The changes to horovod installation flags are part of a deliberate evolution in the CI pipeline:
- Initially used `--only-binary=horovod` with `horovod[tensorflow-cpu]`
- Transitioned to a separate tensorflow-cpu installation and simplified horovod installation
- The `--no-build-isolation` flag was added after testing and multiple iterations
The changes are well-tested through multiple commits and align with the project's move to use uv package installer and separate dependency management. The current approach has been stable in the CI pipeline.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Verify if horovod installation succeeds with these flags
uv pip install horovod --no-build-isolation 2>&1 | grep -i "error\|warning"
```
Length of output: 72
Script:
```bash
#!/bin/bash
# Let's check the git history to understand the context of these changes
git log -p -- .github/workflows/test_python.yml | grep -A 10 -B 10 "horovod"

# Also check if there are any related discussions or issues
git log --grep="horovod" --pretty=format:"%h %s"

# Look for any related configuration or build files
fd horovod
```
Length of output: 13203
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
.github/workflows/test_python.yml (1)
60-65: Consider expanding TF2 eager mode test coverage.

While the addition of TF2 eager mode testing is valuable, consider:

- The test only covers `test_io.py`; should we include other JAX2TF-related tests?
- Running only in group 1 is efficient but makes the test results less visible in the CI output.
Consider adding a comment explaining why this test is isolated to group 1 and whether there are plans to expand coverage.
pyproject.toml (1)
447-447: Consider separating MPI configuration changes.

The addition of the OpenMPI source appears to be independent of the main PR objective (JAX backend and SavedModel export). Consider moving this change to a separate PR to maintain clear change boundaries and simplify review processes.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- .github/workflows/test_python.yml (2 hunks)
- pyproject.toml (1 hunks)
🔇 Additional comments (2)
.github/workflows/test_python.yml (1)
41-44: LGTM! Environment variables properly configured for TensorFlow support.
The explicit Horovod configuration flags ensure proper build settings for TensorFlow support while disabling unnecessary PyTorch integration. This aligns well with the JAX2TF focus of the PR.
pyproject.toml (1)
447-447: LGTM! Consider documenting MPI implementation choice.
The addition of `openmpi` as an alternative source for `mpi4py` alongside the existing `mpich` source provides users with flexibility in choosing their MPI implementation. This is a good practice for supporting different environments and user preferences.
Let's verify the MPI configuration in the codebase:
Consider adding documentation to clarify:
- The differences between the `mpich` and `openmpi` options
- Any known compatibility considerations
✅ Verification successful
Based on the search results, I can now generate the final response since I have sufficient information to verify the MPI configuration concern:
Both OpenMPI and MPICH implementations are supported and can coexist safely
The codebase is designed to work with different MPI implementations, as evidenced by:
- The build system properly detects and configures MPI settings at compile time
- Documentation explicitly mentions support for both MPICH and OpenMPI
- MPI-specific configurations are properly isolated and handled through CMake's find_package(MPI)
The addition of OpenMPI as an alternative source for mpi4py is safe and provides users with flexibility in choosing their preferred MPI implementation. The existing MPICH configuration remains unaffected.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Check for any MPI-specific configuration or requirements
# Test: Search for MPI-related configuration and documentation
rg -i "mpi" --type md
rg "MPI" -A 5 --type cmake
```
Length of output: 32054
Summary by CodeRabbit

Release Notes:

- New Features
- Bug Fixes
- Documentation
- Tests