Skip to content

Commit f4653cb

Browse files
Copilot and dzzz2001 committed
Implement device=auto with safe GPU probe and MPI broadcast
- Add probe_gpu_availability() function that safely checks GPU without exiting
- Refactor get_device_flag() to support device=auto with clear user feedback
- Implement MPI broadcast for consistent device selection across all ranks
- Add safety check in get_device_kpar() to prevent GPU calls when unavailable
- Add informative messages for device selection (INFO/WARNING)

Co-authored-by: dzzz2001 <153698752+dzzz2001@users.noreply.github.com>
1 parent 996199b commit f4653cb

File tree

2 files changed

+105
-47
lines changed

2 files changed

+105
-47
lines changed

source/source_base/module_device/device.cpp

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -147,58 +147,110 @@ int set_device_by_rank(const MPI_Comm mpi_comm) {
147147

148148
#endif
149149

150-
std::string get_device_flag(const std::string &device,
151-
const std::string &basis_type) {
152-
if (device == "cpu") {
153-
return "cpu"; // no extra checks required
154-
}
155-
std::string error_message;
156-
if (device != "auto" and device != "gpu")
157-
{
158-
error_message += "Parameter \"device\" can only be set to \"cpu\" or \"gpu\"!";
159-
ModuleBase::WARNING_QUIT("device", error_message);
160-
}
161-
162-
// Get available GPU count
163-
int device_count = -1;
164-
#if ((defined __CUDA) || (defined __ROCM))
150+
/**
 * @brief Check whether at least one GPU can be enumerated, without aborting on failure.
 *
 * The runtime's device-count query is called directly and its status code is
 * inspected here, instead of being routed through an error-checking wrapper,
 * so that a failing query is reported to the caller as `false` rather than
 * terminating the program.
 *
 * @return true if the device-count query succeeds and reports more than zero
 *         devices; false otherwise. Always false when the code is built with
 *         neither __CUDA nor __ROCM defined.
 */
bool probe_gpu_availability() {
    int gpu_total = 0;
#if defined(__CUDA)
    // Deliberately not wrapped in cudaErrcheck: a failure here must not exit.
    const bool query_ok = (cudaGetDeviceCount(&gpu_total) == cudaSuccess);
#elif defined(__ROCM)
    // Same reasoning as the CUDA branch: inspect the status code ourselves.
    const bool query_ok = (hipGetDeviceCount(&gpu_total) == hipSuccess);
#else
    // Built without GPU support: there is nothing to query.
    const bool query_ok = false;
#endif
    return query_ok && gpu_total > 0;
}
179-
#else // CPU only
180-
error_message += "ABACUS is built with CPU support only. Please rebuild with GPU support.\n";
171+
172+
std::string get_device_flag(const std::string &device,
173+
const std::string &basis_type) {
174+
// 1. Validate input string
175+
if (device != "cpu" && device != "gpu" && device != "auto") {
176+
ModuleBase::WARNING_QUIT("device", "Parameter \"device\" can only be set to \"cpu\", \"gpu\", or \"auto\"!");
177+
}
178+
179+
int decision = 0; // 0 for CPU, 1 for GPU
180+
181+
#ifdef __MPI
182+
int world_rank = 0;
183+
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
184+
185+
if (world_rank == 0) {
186+
// Rank 0 makes the decision
187+
if (device == "gpu") {
188+
if (probe_gpu_availability()) {
189+
decision = 1;
190+
std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
191+
} else {
192+
ModuleBase::WARNING_QUIT("device", "Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'.");
193+
}
194+
} else if (device == "auto") {
195+
if (probe_gpu_availability()) {
196+
decision = 1;
197+
std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
198+
} else {
199+
decision = 0;
200+
std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
201+
std::cout << " To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
202+
}
203+
} else { // device == "cpu"
204+
decision = 0;
205+
std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
206+
}
207+
}
208+
209+
// Rank 0 broadcasts the final decision to all other ranks
210+
MPI_Bcast(&decision, 1, MPI_INT, 0, MPI_COMM_WORLD);
211+
#else
212+
// Non-MPI case: single process makes the decision
213+
if (device == "gpu") {
214+
if (probe_gpu_availability()) {
215+
decision = 1;
216+
std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
217+
} else {
218+
ModuleBase::WARNING_QUIT("device", "Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'.");
219+
}
220+
} else if (device == "auto") {
221+
if (probe_gpu_availability()) {
222+
decision = 1;
223+
std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
224+
} else {
225+
decision = 0;
226+
std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
227+
std::cout << " To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
228+
}
229+
} else { // device == "cpu"
230+
decision = 0;
231+
std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
232+
}
181233
#endif
182234

183-
if (basis_type == "lcao_in_pw") {
184-
error_message +=
185-
"The GPU currently does not support the basis type \"lcao_in_pw\"!";
186-
}
187-
if(error_message.empty())
188-
{
189-
return "gpu"; // possibly automatically set to GPU
190-
}
191-
else if (device == "gpu")
192-
{
193-
ModuleBase::WARNING_QUIT("device", error_message);
194-
}
195-
else { return "cpu";
196-
}
235+
// 2. Final check for incompatible basis type
236+
if (decision == 1 && basis_type == "lcao_in_pw") {
237+
ModuleBase::WARNING_QUIT("device", "The GPU currently does not support the basis type \"lcao_in_pw\"!");
238+
}
239+
240+
// 3. Return the final decision
241+
return (decision == 1) ? "gpu" : "cpu";
197242
}
198243

199244
int get_device_kpar(const int& kpar, const int& bndpar)
200245
{
201246
#if __MPI && (__CUDA || __ROCM)
247+
// This function should only be called when GPU mode is active
248+
// We use probe_gpu_availability to ensure GPU is actually available
249+
if (!probe_gpu_availability()) {
250+
// If no GPU available, return kpar unchanged
251+
return kpar;
252+
}
253+
202254
int temp_nproc = 0;
203255
int new_kpar = kpar;
204256
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
@@ -213,15 +265,15 @@ int get_device_kpar(const int& kpar, const int& bndpar)
213265

214266
int device_num = -1;
215267
#if defined(__CUDA)
216-
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
217-
cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
268+
cudaErrcheck(cudaGetDeviceCount(&device_num)); // get the number of GPU devices of current node
269+
cudaErrcheck(cudaSetDevice(node_rank % device_num)); // bind the CPU processor to the devices
218270
#elif defined(__ROCM)
219-
hipGetDeviceCount(&device_num);
220-
hipSetDevice(node_rank % device_num);
271+
hipErrcheck(hipGetDeviceCount(&device_num));
272+
hipErrcheck(hipSetDevice(node_rank % device_num));
221273
#endif
222-
return new_kpar;
274+
return new_kpar;
223275
#endif
224-
return kpar;
276+
return kpar;
225277
}
226278

227279
} // end of namespace information

source/source_base/module_device/device.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ void output_device_info(std::ostream& output);
4444
*/
4545
int get_device_kpar(const int& kpar, const int& bndpar);
4646

47+
/**
48+
* @brief Safely probes for GPU availability without exiting on error.
49+
* @return True if at least one GPU is found and usable, false otherwise.
50+
*/
51+
bool probe_gpu_availability();
52+
4753
/**
4854
* @brief Get the device flag object
4955
* for source_io PARAM.inp.device

0 commit comments

Comments
 (0)