feat(pt): Implement type embedding compression for se_e3_tebd #5059
Conversation
- Introduced a new method for type embedding compression, allowing precomputation of strip-mode type embeddings for all type pairs.
- Added runtime checks to ensure compatibility with the strip input mode and the existence of the necessary filter layers.
- Updated the forward method to use precomputed type embeddings when available, improving performance during inference.

These changes optimize the handling of type embeddings, enhancing the efficiency of the descriptor model.
- Changed the assignment of type embedding data to use `register_buffer`, improving memory management.
- Introduced a new variable `embd_tensor` for clarity and maintainability in the embedding computation.

These modifications improve the structure of the code while preserving existing functionality.
…escrptBlockSeTTebd
…mbedding handling
Pull request overview
This PR implements type embedding compression for the se_e3_tebd descriptor: type embedding outputs are pre-computed and stored when the model is compressed, avoiding redundant computation during inference.
Key changes:
- Added a type embedding network parameter to the compression pipeline
- Pre-computes type embeddings for all type pairs and stores them in the `type_embd_data` buffer
- Modified the forward pass to use pre-computed embeddings when compression is enabled (see the sketch below)
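For orientation, here is a minimal sketch of the precompute-and-look-up idea. `get_full_embedding` is the helper referenced in this review's code graph, though its exact signature, the shapes, and the pair-flattening rule are assumptions for illustration, not the PR's actual implementation:

```python
import torch

def precompute_type_pair_embeddings(type_embedding_net, ntypes_with_padding):
    # Full per-type embedding table, shape (ntypes_with_padding, tebd_dim);
    # the device argument of get_full_embedding is assumed here.
    embd = type_embedding_net.get_full_embedding("cpu")
    # Enumerate all ordered type pairs (i, j); row i * n + j of the result
    # holds the concatenated embeddings of types i and j.
    n = ntypes_with_padding
    ii = torch.arange(n).repeat_interleave(n)
    jj = torch.arange(n).repeat(n)
    return torch.cat([embd[ii], embd[jj]], dim=-1)  # (n * n, 2 * tebd_dim)
```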
📝 Walkthrough

Modified SE-T TEBD descriptor compression to restrict compression to the "strip" input mode and introduced a `type_embd_data` buffer holding precomputed type embeddings for all type pairs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Descriptor as DescrptSeTTebd
    participant Block as DescrptBlockSeTTebd
    participant Net as TypeEmbedNet
    User->>Descriptor: enable_compression(type_embedding_net, ...)
    rect rgb(200, 220, 255)
        Note over Descriptor: Guard check: tebd_input_mode == "strip"
    end
    Descriptor->>Descriptor: Validate filter_layers_strip exists
    Descriptor->>Block: enable_compression(type_embedding_net, ...)
    Block->>Net: Compute type embeddings
    Net-->>Block: Precomputed embedding tensor
    rect rgb(220, 240, 220)
        Block->>Block: register_buffer(type_embd_data, embedding)
    end
    Block-->>Descriptor: Compression enabled
    User->>Descriptor: forward(...)
    alt compress is True (strip mode)
        rect rgb(255, 240, 200)
            Note over Block: Use precomputed type_embd_data
            Block-->>Descriptor: Type-related embeddings from buffer
        end
    else compress is False
        rect rgb(240, 220, 255)
            Note over Block: Compute type embeddings on-the-fly
            Block-->>Descriptor: Type-related embeddings computed
        end
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) | ✅ Passed checks (2)
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/se_t_tebd.py (1)
1092-1108: Consider simplifying the buffer update pattern.

The type embedding compression logic is correct: it precomputes all type-pair embeddings and stores them for efficient lookup during inference. However, the pattern of deleting and re-registering the buffer could be simplified.

Instead of:

```python
if hasattr(self, "type_embd_data"):
    del self.type_embd_data
self.register_buffer("type_embd_data", embd_tensor)
```

consider directly updating the buffer data:

```diff
-        if hasattr(self, "type_embd_data"):
-            del self.type_embd_data
-        self.register_buffer("type_embd_data", embd_tensor)
+        self.type_embd_data = embd_tensor
```

Since `type_embd_data` is already registered as a buffer in `__init__`, directly assigning the new tensor will update the buffer while maintaining proper serialization behavior. This approach is simpler and avoids potential issues with buffer management across multiple `enable_compression` calls.
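The suggestion relies on a PyTorch behavior worth spelling out: `nn.Module.__setattr__` routes a tensor assigned to an existing buffer name into `_buffers`, so the attribute stays a buffer. A minimal standalone demonstration (class and shapes are illustrative):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered once here; later plain assignment updates the same buffer.
        self.register_buffer("type_embd_data", torch.zeros(1))

m = Demo()
m.type_embd_data = torch.ones(4)           # routed into m._buffers by nn.Module.__setattr__
print("type_embd_data" in m.state_dict())  # True: still serialized as a buffer
```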
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/model/descriptor/se_t_tebd.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/model/descriptor/se_t_tebd.py (2)
deepmd/dpmodel/utils/type_embed.py (1)
  - TypeEmbedNet (30-221)
deepmd/pt/model/network/network.py (2)
  - TypeEmbedNet (268-358)
  - get_full_embedding (317-332)
🪛 Ruff (0.14.5)
deepmd/pt/model/descriptor/se_t_tebd.py
- 557-557: Avoid specifying long messages outside the exception class (TRY003)
- 1068-1068: Avoid specifying long messages outside the exception class (TRY003)
- 1070-1072: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (31)
- GitHub Check: Agent
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (5)
deepmd/pt/model/descriptor/se_t_tebd.py (5)
556-557: LGTM: Compression restricted to strip mode.

The guard correctly restricts compression to "strip" mode, which aligns with the PR's implementation strategy for type embedding compression.
700-710: LGTM: Type embedding buffer properly registered.

The `type_embd_data` buffer is correctly registered for storing precomputed type-embedding compression data, with appropriate initialization and device placement.
997-1014: LGTM: Compression path correctly uses precomputed embeddings.

The conditional logic properly switches between the precomputed `type_embd_data` (compressed) and on-the-fly computation (uncompressed). The index calculation for type pairs is correct.

Note: the code assumes `idx` values are within the bounds of `tt_full`. This should be safe given that `ntypes_with_padding` defines both the index range and the `type_embd_data` dimensions, but consider adding an assertion if defensive checks are desired.
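As a hedged illustration of the index calculation and the optional defensive check mentioned above (`type_i`/`type_j` and the function boundary are assumptions; only `type_embd_data` and `ntypes_with_padding` come from the review):

```python
import torch

def lookup_pair_embedding(type_embd_data, type_i, type_j, ntypes_with_padding):
    # Flattened pair index: entry (i, j) lives at row i * N + j of the table.
    idx = type_i * ntypes_with_padding + type_j
    # The optional defensive check suggested above.
    assert int(idx.max()) < type_embd_data.shape[0], "type pair index out of bounds"
    return type_embd_data[idx]
```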
1044-1072: LGTM: Proper validation for type embedding compression requirements.

The updated signature and validation checks correctly enforce that:
- compression only works in "strip" mode
- `filter_layers_strip` must exist

The error messages are clear and help users understand the constraints.
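A minimal sketch of those guards, assuming the attribute names from the review (`tebd_input_mode`, `filter_layers_strip`); the enclosing class, error wording, and remaining parameters are hypothetical:

```python
class _CompressionGuard:
    def __init__(self, tebd_input_mode, filter_layers_strip):
        self.tebd_input_mode = tebd_input_mode
        self.filter_layers_strip = filter_layers_strip

    def enable_compression(self, type_embedding_net):
        # Compression is only defined for the "strip" input mode.
        if self.tebd_input_mode != "strip":
            raise RuntimeError("type embedding compression requires tebd_input_mode == 'strip'")
        # The strip filter layers must have been built at construction time.
        if self.filter_layers_strip is None:
            raise RuntimeError("filter_layers_strip is required to enable compression")
        # ... proceed to precompute type embeddings and tabulate networks ...
```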
1074-1091: LGTM: Geometric embedding compression setup.

The compression configuration for the geometric embedding network is properly initialized with bounds, table config, and data.
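For readers unfamiliar with table-based compression in general: the idea is to sample a small network once on a uniform grid bounded by the table config, then replace forward evaluations with interpolation into the table. The sketch below is a generic illustration under that assumption, not deepmd-kit's actual table format:

```python
import torch

def build_table(net, lower, upper, stride):
    # Tabulate the network's outputs on a uniform grid, once.
    xs = torch.arange(lower, upper, stride).unsqueeze(-1)  # (n, 1) grid
    with torch.no_grad():
        return xs, net(xs)                                 # (n, 1), (n, d)

def lookup(xs, table, x, lower, stride):
    # Linear interpolation between the two nearest grid points.
    pos = (x - lower) / stride
    i0 = pos.long().clamp(0, len(xs) - 2)
    w = (pos - i0.float()).unsqueeze(-1)
    return (1 - w) * table[i0] + w * table[i0 + 1]
```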
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```diff
@@           Coverage Diff           @@
##            devel    #5059   +/-   ##
=======================================
  Coverage   84.24%   84.24%
=======================================
  Files         709      709
  Lines       70236    70296     +60
  Branches     3623     3619      -4
=======================================
+ Hits        59169    59221     +52
- Misses       9900     9906      +6
- Partials     1167     1169      +2
```

☔ View full report in Codecov by Sentry.
The unit tests for this descriptor will be submitted in a separate PR.
Summary by CodeRabbit
Breaking Changes
Improvements