-
Notifications
You must be signed in to change notification settings - Fork 582
Add 4 pt descriptor compression #4227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…o devel pt_tabulate_fusion recommit by Yan.
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 8
🧹 Outside diff range and nitpick comments (1)
deepmd/tf/utils/tabulate.py (1)
37-37: Add type hints for class attributes.The class should define type hints for its attributes to improve code maintainability and IDE support.
Add type hints at the class level:
class DPTabulate(BaseTabulate): descrpt_type: str graph: tf.Graph graph_def: tf.GraphDef suffix: str functype: int activation_fn: Callable[[tf.Tensor], tf.Tensor]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py(1 hunks)deepmd/tf/utils/tabulate.py(3 hunks)deepmd/utils/tabulate.py(1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/pt/utils/tabulate.py
54-54: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
55-55: Do not perform function call ActivationFn in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
deepmd/utils/tabulate.py
353-353: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
🔇 Additional comments (11)
deepmd/utils/tabulate.py (4)
19-65: LGTM! Well-structured class initialization.
The class is well-organized with clear documentation and proper use of abstract base class pattern.
353-368: Remove @lru_cache to prevent memory leaks.
The existing review comment about potential memory leaks with @lru_cache on instance methods is still valid.
🧰 Tools
🪛 Ruff
353-353: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
449-457:
Add validation for rmax and rmin parameters.
The method performs division by (rmax - rmin) without validating that they are different.
Add parameter validation:
def _spline5_switch(self, xx, rmin, rmax):
+ if rmax <= rmin:
+ raise ValueError("rmax must be greater than rmin")
if xx < rmin:
vv = 1
elif xx < rmax:Likely invalid or redundant comment.
429-448:
Add validation for min_nbor_dist parameter.
The method uses min_nbor_dist in divisions without validating that it's positive and non-zero.
Add parameter validation:
def _get_env_mat_range(self, min_nbor_dist):
+ if min_nbor_dist <= 0:
+ raise ValueError("min_nbor_dist must be positive")
"""Change the embedding net range to sw / min_nbor_dist."""
sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut)Likely invalid or redundant comment.
deepmd/tf/utils/tabulate.py (5)
76-82: LGTM: Proper initialization of the parent class.
The constructor correctly initializes the parent class with the required parameters.
84-84: Consider using a cached property for descriptor type.
The descrpt_type is computed only once during initialization and doesn't change. Using @cached_property would be more idiomatic and consistent with other properties in the class.
30-32: LGTM: Clean import of base class.
The import of BaseTabulate is clean and follows Python import conventions.
172-172:
Fix implicit return in _get_descrpt_type.
The method has an implicit return path when no conditions match, which could lead to runtime errors.
76-82: Verify descriptor type initialization.
The initialization of descrpt_type happens after the parent class initialization. Let's verify that the parent class doesn't depend on this attribute.
Also applies to: 84-84
✅ Verification successful
Let me gather more information about the parent class and its initialization to make a conclusive decision.
Let me try a different approach to find the parent class and its usage of descrpt_type.
Let me check the parent class (BaseTabulate) to see how it uses the descriptor.
Based on the code analysis, I can now provide the final response:
The initialization order is correct and safe. The parent class BaseTabulate only initializes shared attributes and sets descrpt_type to a default value "Base". The actual descriptor type is determined by the child class's _get_descrpt_type method, which is called after parent initialization. The parent class doesn't use descrpt_type during initialization, and its methods that use this attribute are only called later during the build method.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any usage of descrpt_type in the parent class
rg -l "descrpt_type" | grep -v "tabulate.py$" | while read -r file; do
echo "=== $file ==="
rg "descrpt_type" "$file"
done
Length of output: 972
Script:
#!/bin/bash
# Find the parent class implementation and its initialization
ast-grep --pattern 'class $_ {
$$$
__init__($$$) {
$$$
}
$$$
}' deepmd/tf/utils/tabulate.py
Length of output: 105
Script:
#!/bin/bash
# First, let's see the content of tabulate.py to understand the inheritance
cat deepmd/tf/utils/tabulate.py
# Then search for any usage of descrpt_type in the file
rg -A 5 -B 5 "descrpt_type" deepmd/tf/utils/tabulate.py
Length of output: 23892
Script:
#!/bin/bash
# Check the BaseTabulate class implementation
cat deepmd/utils/tabulate.py
Length of output: 16199
deepmd/pt/utils/tabulate.py (2)
515-516:
Improve error handling in grad_grad function
The function returns -1 for unsupported function types, while the grad function raises a ValueError. This inconsistency could lead to silent failures.
Make error handling consistent:
elif functype == 6:
return y * (1 - y) * (1 - 2 * y)
- else:
- return -1
+ raise ValueError(f"Unsupported function type: {functype}")Likely invalid or redundant comment.
321-321:
Add division by zero check in layer size calculation
The calculation basic_size // (self.ntypes - self._n_all_excluded) could raise a ZeroDivisionError if all types are excluded.
Add a check:
if self.type_one_side:
+ denominator = self.ntypes - self._n_all_excluded
+ if denominator == 0:
+ raise ValueError("Cannot calculate layer size: all types are excluded")
- layer_size = basic_size // (self.ntypes - self._n_all_excluded)
+ layer_size = basic_size // denominatorLikely invalid or redundant comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🧹 Outside diff range and nitpick comments (4)
deepmd/utils/tabulate.py (3)
22-65: Enhance constructor documentation with type hints and attribute descriptions.The constructor's documentation could be improved by:
- Adding type hints for all parameters
- Documenting the purpose of each parameter
- Documenting the attributes that must be initialized in subclasses
Apply this diff to improve the documentation:
def __init__( self, - descrpt, - neuron, - type_one_side, - exclude_types, - is_pt, + descrpt: Any, # TODO: Add specific type + neuron: list[int], + type_one_side: bool, + exclude_types: set[tuple[int, int]], + is_pt: bool, ) -> None: - """Constructor.""" + """Initialize the base tabulate class. + + Parameters + ---------- + descrpt : Any + The descriptor object + neuron : list[int] + List of neurons in each layer + type_one_side : bool + Whether to use one-sided type + exclude_types : set[tuple[int, int]] + Set of type pairs to exclude + is_pt : bool + Whether this is a PyTorch implementation + + Notes + ----- + The following attributes must be initialized in subclasses: + - descrpt_type: str + - sel_a: list + - rcut: float + - rcut_smth: float + - davg: np.ndarray + - dstd: np.ndarray + - ntypes: int + """
336-423: Enhance abstract method documentation with complete type hints.The abstract methods would benefit from more detailed documentation and complete type hints.
Example improvement for
_get_descrpt_type:@abstractmethod - def _get_descrpt_type(self): - """Get the descrpt type.""" + def _get_descrpt_type(self) -> str: + """Get the descriptor type. + + Returns + ------- + str + The type of descriptor. Must be one of: + - "Atten" + - "A" + - "T" + - "R" + - "AEbdV2" + """ pass🧰 Tools
🪛 Ruff
354-354: Use of
functools.lru_cacheorfunctools.cacheon methods can lead to memory leaks(B019)
1-458: Add unit tests for mathematical operations.The file contains complex mathematical operations, particularly in the
buildand_build_lowermethods. Consider adding unit tests to verify:
- Correct calculation of spline coefficients
- Proper handling of boundary conditions
- Accuracy of tabulation results
Would you like me to help generate comprehensive unit tests for these mathematical operations?
🧰 Tools
🪛 Ruff
354-354: Use of
functools.lru_cacheorfunctools.cacheon methods can lead to memory leaks(B019)
deepmd/pt/utils/tabulate.py (1)
81-89: Moveactivation_mapto a module-level constantThe
activation_mapdictionary is defined inside the__init__method. Since it does not depend on any instance-specific data, defining it at the module level can improve code clarity and prevent it from being recreated with each instance.You can move
activation_mapoutside the class definition:# Module-level constant ACTIVATION_MAP = { "tanh": 1, "gelu": 2, "gelu_tf": 2, "relu": 3, "relu6": 4, "softplus": 5, "sigmoid": 6, } class DPTabulate(BaseTabulate): def __init__(self, ...): # Use ACTIVATION_MAP here
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py(1 hunks)deepmd/utils/tabulate.py(1 hunks)source/tests/pt/test_tabulate.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_tabulate.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/utils/tabulate.py
54-54: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
55-55: Do not perform function call ActivationFn in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
deepmd/utils/tabulate.py
354-354: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
se_a, se_atten(DPA1), se_t, se_r
Summary by CodeRabbit
Release Notes
New Features
enable_compressionmethods to various classes, allowing users to enable and configure compression settings.Bug Fixes
Tests
Documentation