Skip to content

Commit 04e1159

Browse files
authored
fix(pt): set device for PT C++ (#4261)
Fix #4171. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Improved GPU initialization to ensure the correct device is utilized. - Enhanced error handling for clearer context on exceptions. - **Bug Fixes** - Updated error handling in multiple methods to catch and rethrow specific exceptions. - Added logic to handle communication-related tensors during computation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 39cddd4 commit 04e1159

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

source/api_cc/src/DeepPotPT.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ void DeepPotPT::init(const std::string& model,
8080
device = torch::Device(torch::kCPU);
8181
std::cout << "load model from: " << model << " to cpu " << std::endl;
8282
} else {
83+
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
84+
DPErrcheck(DPSetDevice(gpu_id));
85+
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
8386
std::cout << "load model from: " << model << " to gpu " << gpu_id
8487
<< std::endl;
8588
}

0 commit comments

Comments
 (0)