Skip to content

Commit 778d420

Browse files
Copilotdzzz2001
andauthored
Fix CUDA GPU detection crash when running with device=cpu on machines without GPU (#6759)
* Initial plan * Fix CUDA GPU detection when device is set to CPU - Add runtime check for PARAM.inp.device before calling GPU detection functions - Prevent cudaErrcheck from exiting when no GPU is available but device is set to "cpu" - Apply same logic as non-MPI version to MPI version in output_device_info Co-authored-by: dzzz2001 <153698752+dzzz2001@users.noreply.github.com> * Implement device=auto with safe GPU probe and MPI broadcast - Add probe_gpu_availability() function that safely checks GPU without exiting - Refactor get_device_flag() to support device=auto with clear user feedback - Implement MPI broadcast for consistent device selection across all ranks - Add safety check in get_device_kpar() to prevent GPU calls when unavailable - Add informative messages for device selection (INFO/WARNING) Co-authored-by: dzzz2001 <153698752+dzzz2001@users.noreply.github.com> * Fix MPI synchronization issue in get_device_kpar - Remove redundant probe_gpu_availability check that could cause MPI inconsistency - Function is only called when device=gpu is already confirmed - Add comment clarifying the function should only be called in GPU mode Co-authored-by: dzzz2001 <153698752+dzzz2001@users.noreply.github.com> * 注释掉 get_device_flag 中所有 std::cout 输出,避免冗余信息打印 --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: dzzz2001 <153698752+dzzz2001@users.noreply.github.com> Co-authored-by: dzzz2001 <dzcpku2001@163.com>
1 parent 6b96629 commit 778d420

File tree

3 files changed

+86
-51
lines changed

3 files changed

+86
-51
lines changed

source/source_base/module_device/device.cpp

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#include "device.h"
32

43
#include "source_base/tool_quit.h"
@@ -147,58 +146,76 @@ int set_device_by_rank(const MPI_Comm mpi_comm) {
147146

148147
#endif
149148

150-
std::string get_device_flag(const std::string &device,
151-
const std::string &basis_type) {
152-
if (device == "cpu") {
153-
return "cpu"; // no extra checks required
154-
}
155-
std::string error_message;
156-
if (device != "auto" and device != "gpu")
157-
{
158-
error_message += "Parameter \"device\" can only be set to \"cpu\" or \"gpu\"!";
159-
ModuleBase::WARNING_QUIT("device", error_message);
160-
}
161-
162-
// Get available GPU count
163-
int device_count = -1;
164-
#if ((defined __CUDA) || (defined __ROCM))
149+
bool probe_gpu_availability() {
165150
#if defined(__CUDA)
166-
cudaGetDeviceCount(&device_count);
151+
int device_count = 0;
152+
// Directly call cudaGetDeviceCount without cudaErrcheck to prevent program exit
153+
cudaError_t error_id = cudaGetDeviceCount(&device_count);
154+
if (error_id == cudaSuccess && device_count > 0) {
155+
return true;
156+
}
157+
return false;
167158
#elif defined(__ROCM)
168-
hipGetDeviceCount(&device_count);
169-
/***auto start_time = std::chrono::high_resolution_clock::now();
170-
std::cout << "Starting hipGetDeviceCount.." << std::endl;
171-
auto end_time = std::chrono::high_resolution_clock::now();
172-
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end_time - start_time);
173-
std::cout << "hipGetDeviceCount took " << duration.count() << "seconds" << std::endl;***/
159+
int device_count = 0;
160+
hipError_t error_id = hipGetDeviceCount(&device_count);
161+
if (error_id == hipSuccess && device_count > 0) {
162+
return true;
163+
}
164+
return false;
165+
#else
166+
// If not compiled with GPU support, GPU is not available
167+
return false;
174168
#endif
175-
if (device_count <= 0)
176-
{
177-
error_message += "Cannot find GPU on this computer!\n";
178169
}
179-
#else // CPU only
180-
error_message += "ABACUS is built with CPU support only. Please rebuild with GPU support.\n";
181-
#endif
182170

183-
if (basis_type == "lcao_in_pw") {
184-
error_message +=
185-
"The GPU currently does not support the basis type \"lcao_in_pw\"!";
186-
}
187-
if(error_message.empty())
188-
{
189-
return "gpu"; // possibly automatically set to GPU
190-
}
191-
else if (device == "gpu")
192-
{
193-
ModuleBase::WARNING_QUIT("device", error_message);
194-
}
195-
else { return "cpu";
196-
}
171+
std::string get_device_flag(const std::string &device,
172+
const std::string &basis_type) {
173+
// 1. Validate input string
174+
if (device != "cpu" && device != "gpu" && device != "auto") {
175+
ModuleBase::WARNING_QUIT("device", "Parameter \"device\" can only be set to \"cpu\", \"gpu\", or \"auto\"!");
176+
}
177+
178+
// NOTE: This function is called only on rank 0 during input parsing.
179+
// The result will be broadcast to other ranks via the standard bcast mechanism.
180+
// DO NOT use MPI_Bcast here as other ranks are not in this code path.
181+
182+
std::string result = "cpu";
183+
184+
if (device == "gpu") {
185+
if (probe_gpu_availability()) {
186+
result = "gpu";
187+
// std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
188+
} else {
189+
ModuleBase::WARNING_QUIT("device", "Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'.");
190+
}
191+
} else if (device == "auto") {
192+
if (probe_gpu_availability()) {
193+
result = "gpu";
194+
// std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
195+
} else {
196+
result = "cpu";
197+
// std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
198+
// std::cout << " To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
199+
}
200+
} else { // device == "cpu"
201+
result = "cpu";
202+
// std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
203+
}
204+
205+
// 2. Final check for incompatible basis type
206+
if (result == "gpu" && basis_type == "lcao_in_pw") {
207+
ModuleBase::WARNING_QUIT("device", "The GPU currently does not support the basis type \"lcao_in_pw\"!");
208+
}
209+
210+
// 3. Return the final decision
211+
return result;
197212
}
198213

199214
int get_device_kpar(const int& kpar, const int& bndpar)
200215
{
201216
#if __MPI && (__CUDA || __ROCM)
217+
// This function should only be called when device mode is GPU
218+
// The device decision has already been made by get_device_flag()
202219
int temp_nproc = 0;
203220
int new_kpar = kpar;
204221
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
@@ -213,15 +230,15 @@ int get_device_kpar(const int& kpar, const int& bndpar)
213230

214231
int device_num = -1;
215232
#if defined(__CUDA)
216-
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
217-
cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
233+
cudaErrcheck(cudaGetDeviceCount(&device_num)); // get the number of GPU devices of current node
234+
cudaErrcheck(cudaSetDevice(node_rank % device_num)); // bind the CPU processor to the devices
218235
#elif defined(__ROCM)
219-
hipGetDeviceCount(&device_num);
220-
hipSetDevice(node_rank % device_num);
236+
hipErrcheck(hipGetDeviceCount(&device_num));
237+
hipErrcheck(hipSetDevice(node_rank % device_num));
221238
#endif
222-
return new_kpar;
239+
return new_kpar;
223240
#endif
224-
return kpar;
241+
return kpar;
225242
}
226243

227244
} // end of namespace information

source/source_base/module_device/device.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ void output_device_info(std::ostream& output);
4444
*/
4545
int get_device_kpar(const int& kpar, const int& bndpar);
4646

47+
/**
48+
* @brief Safely probes for GPU availability without exiting on error.
49+
* @return True if at least one GPU is found and usable, false otherwise.
50+
*/
51+
bool probe_gpu_availability();
52+
4753
/**
4854
* @brief Get the device flag object
4955
* for source_io PARAM.inp.device

source/source_base/module_device/output_device.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,13 @@ void output_device_info(std::ostream &output)
115115
int local_rank = get_node_rank_with_mpi_shared(MPI_COMM_WORLD);
116116

117117
// Get local hardware info
118-
int local_gpu_count = local_rank == 0 ? get_device_num("gpu") : 0;
118+
int local_gpu_count = 0;
119+
#if defined(__CUDA) || defined(__ROCM)
120+
if(PARAM.inp.device == "gpu" && local_rank == 0)
121+
{
122+
local_gpu_count = get_device_num("gpu");
123+
}
124+
#endif
119125
int local_cpu_sockets = local_rank == 0 ? get_device_num("cpu") : 0;
120126

121127
// Prepare vectors to gather data from all ranks
@@ -133,7 +139,13 @@ void output_device_info(std::ostream &output)
133139

134140
// Get device model names (from rank 0 node)
135141
std::string cpu_name = get_device_name("cpu");
136-
std::string gpu_name = get_device_name("gpu");
142+
std::string gpu_name;
143+
#if defined(__CUDA) || defined(__ROCM)
144+
if(PARAM.inp.device == "gpu" && total_gpus > 0)
145+
{
146+
gpu_name = get_device_name("gpu");
147+
}
148+
#endif
137149

138150
// Output all collected information
139151
output << " RUNNING WITH DEVICE : " << "CPU" << " / "

0 commit comments

Comments
 (0)