feat: use bfloat16 with torch.autocast on training
#4741
Conversation
Pull Request Overview
This PR optimizes training memory usage by leveraging torch.autocast with torch.bfloat16 and ensuring tensor type consistency in various computational routines.
- Added explicit type casting for operator outputs in utils.py and repflow_layer.py to match the input tensor data types.
- Introduced torch.autocast on the forward method in train/wrapper.py to enable bfloat16 precision during training.
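For context, a brief illustration of what such an autocast region does (a hedged sketch, not code from this PR; it assumes a CUDA device with bfloat16 support):

```python
import torch

# Inside an autocast region, autocast-eligible ops such as matmul run in
# bfloat16 even though the inputs/parameters are stored in float32, so the
# intermediate activations take 2 bytes per element instead of 4.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        c = a @ b
    print(c.dtype)  # torch.bfloat16
```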
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| deepmd/pt/utils/utils.py | Cast the output of silut_forward_script to x.dtype to handle type mismatches. |
| deepmd/pt/train/wrapper.py | Added torch.autocast decorator for bfloat16 operations on CUDA. |
| deepmd/pt/model/descriptor/repflow_layer.py | Applied explicit casts on bias values and a_sw to ensure dtype consistency. |
📝 Walkthrough
The changes introduce explicit tensor data type casting in arithmetic operations within the model descriptor and utility functions to ensure dtype consistency. A global flag (BF16_AUTOCAST, per the sequence diagram below) toggles bfloat16 autocast on the training wrapper's forward pass.
Sequence Diagram(s)
sequenceDiagram
participant User
participant ModelWrapper
participant CUDA
User->>ModelWrapper: Call forward(input)
activate ModelWrapper
ModelWrapper->>CUDA: Enable autocast (bfloat16) if BF16_AUTOCAST=True
ModelWrapper-->>User: Return output (with AMP if enabled)
deactivate ModelWrapper
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/pt/model/descriptor/repflow_layer.py (3 hunks)
- deepmd/pt/train/wrapper.py (1 hunks)
- deepmd/pt/utils/utils.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (30)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (true)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (4)
deepmd/pt/utils/utils.py (1)
96-96: Excellent fix for dtype consistency in the SiLUT function. This change ensures the output tensor from silut_forward_script maintains the same data type as the input tensor x. This addresses the issue of torch.jit.script not preserving the bfloat16 data type, which is crucial for the mixed precision training feature being implemented.
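As a rough illustration of the fix (the function body below is a simplified stand-in, not the actual deepmd implementation of silut_forward_script):

```python
import torch

@torch.jit.script
def silut_forward_sketch(x: torch.Tensor, threshold: float = 3.0) -> torch.Tensor:
    # Simplified SiLU-with-tanh-tail activation, for illustration only.
    silu = x * torch.sigmoid(x)
    tail = torch.tanh(x)
    out = torch.where(x < threshold, silu, tail)
    # The PR observes that torch.jit.script does not preserve bfloat16 here,
    # so the result is cast back to the input's dtype explicitly.
    return out.to(x.dtype)
```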
deepmd/pt/train/wrapper.py (1)

139-139: ❓ Verification inconclusive
Good implementation of bfloat16 autocast for the forward method.
Adding the @torch.autocast decorator with bfloat16 precision should reduce memory usage by approximately 37%, as mentioned in the PR objectives. Since this feature is currently hardcoded to be enabled, it would be good to make it configurable, as mentioned in the PR objectives:
🏁 Script executed:
#!/bin/bash
# Check if the GPU supports bfloat16
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'Current device: {torch.cuda.get_device_name(0)}'); print(f'bfloat16 supported: {torch.cuda.is_bf16_supported()}')"

Length of output: 341
Make bfloat16 autocast configurable and verify GPU support manually
The @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True) decorator is correctly applied, but since it is hardcoded to enabled=True, we should make it driven by a configuration flag or function parameter as outlined in the PR objectives. Also, please confirm that your target GPU actually supports bfloat16 on CUDA.

Suggested changes:
- Introduce a Boolean config or argument (e.g. use_bf16_autocast) to toggle the autocast decorator.
- Apply the decorator conditionally, for example:

      if use_bf16_autocast:
          wrapper = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
      else:
          wrapper = lambda fn: fn

      @wrapper
      def forward(...):
          ...

- Manually verify bfloat16 support on your hardware:

      python - << 'EOF'
      import torch
      print("CUDA available:", torch.cuda.is_available())
      print("Device:", torch.cuda.get_device_name(0))
      print("bfloat16 supported:", torch.cuda.is_bf16_supported())
      EOF

- Document the fallback behavior (e.g. float32) for platforms without bfloat16.
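A self-contained variant of that suggestion, using a context manager instead of a decorator (illustrative only: the flag name and wrapper class are hypothetical, not the repository's actual ModelWrapper or configuration plumbing):

```python
import torch

USE_BF16_AUTOCAST = True  # hypothetical flag; in practice it would come from the training config


class WrapperSketch(torch.nn.Module):
    """Illustrative stand-in for a training wrapper."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only enable bfloat16 autocast when requested and actually supported.
        enabled = (
            USE_BF16_AUTOCAST
            and torch.cuda.is_available()
            and torch.cuda.is_bf16_supported()
        )
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=enabled):
            return self.linear(x)


# usage (works on CPU too, where the autocast region is simply disabled)
out = WrapperSketch()(torch.randn(2, 8))
```

This keeps float32 as the fallback path on hardware without bfloat16 support.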
deepmd/pt/model/descriptor/repflow_layer.py (2)
427-432: Good dtype consistency improvement in tensor operations. Explicitly casting the bias tensor to match the data type of sub_angle_update ensures consistent precision during arithmetic operations, which is essential for stable mixed precision training.
466-470: Good dtype consistency improvement in tensor operations. Explicitly casting the bias tensor to match the data type of sub_node_update ensures consistent precision during arithmetic operations, which is essential for stable mixed precision training.
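A minimal sketch of the pattern being described (variable names follow the review; the surrounding computation is simplified and not the actual repflow update):

```python
import torch

def add_bias_sketch(sub_node_update: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Under autocast, sub_node_update may be bfloat16 while the bias parameter is
    # stored in float32; without the cast, type promotion would push the whole
    # expression back to float32.
    return sub_node_update + bias.to(sub_node_update.dtype)

update = torch.randn(4, 16).to(torch.bfloat16)
bias = torch.zeros(16)  # float32 parameter
print(add_bias_sketch(update, bias).dtype)  # torch.bfloat16
```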
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files

@@ Coverage Diff @@
## devel #4741 +/- ##
==========================================
- Coverage 84.69% 84.69% -0.01%
==========================================
Files 697 697
Lines 67474 67477 +3
Branches 3540 3541 +1
==========================================
Hits 57147 57147
Misses 9197 9197
- Partials 1130 1133 +3

☔ View full report in Codecov by Sentry.
By setting model.descriptor.precision: "bfloat16", all calculation in the descriptor is performed in bfloat16. The memory reduction is about 50%.
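To illustrate the storage side of that claim, a hedged sketch (the mapping below is written for this example, not taken from the repository):

```python
import torch

# Hypothetical precision-string -> dtype mapping, for illustration only.
PRECISION_TO_DTYPE = {
    "float64": torch.float64,
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

weights = torch.zeros(128, 64, dtype=PRECISION_TO_DTYPE["bfloat16"])
print(weights.element_size())  # 2 bytes per element vs. 4 for float32 -> ~50% less memory
```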
This PR adds a feature that uses torch.autocast with torch.bfloat16 precision when training on NVIDIA GPUs. By adding the decorator to ModelWrapper.forward, torch downcasts the data type for tensor computation and storage. This change is expected to give about a 37% peak memory reduction compared with training under float32.

Todo:
Note:
- The change in repflow_layer.py makes the function output match the input dtype when the input tensor is downcast, even if higher-precision tensors such as the parameters are involved.
- The silut_forward_script function fails to keep the output tensor in bfloat16 when the input tensor has that dtype. Using the raw silut function or torch.compile avoids the problem, so I believe the issue is related to torch.jit.script; I therefore have to manually cast the output dtype to match the input tensor.

Test results:
I trained for 1 million steps on the MPtrj dataset, and the results show that training under BF16 somewhat affects accuracy. This observation is consistent with other MLIP models: training under BF16 may require more training steps, followed by fine-tuning in FP32 to recover accuracy.
Summary by CodeRabbit

Bug Fixes
- Improved consistency of tensor data types in internal computations for more reliable numerical behavior.

New Features
- Added bfloat16 mixed-precision (autocast) support during training to reduce memory usage.