@@ -147,58 +147,110 @@ int set_device_by_rank(const MPI_Comm mpi_comm) {
147147
148148#endif
149149
150- std::string get_device_flag (const std::string &device,
151- const std::string &basis_type) {
152- if (device == " cpu" ) {
153- return " cpu" ; // no extra checks required
154- }
155- std::string error_message;
156- if (device != " auto" and device != " gpu" )
157- {
158- error_message += " Parameter \" device\" can only be set to \" cpu\" or \" gpu\" !" ;
159- ModuleBase::WARNING_QUIT (" device" , error_message);
160- }
161-
162- // Get available GPU count
163- int device_count = -1 ;
164- #if ((defined __CUDA) || (defined __ROCM))
150+ bool probe_gpu_availability () {
165151#if defined(__CUDA)
166- cudaGetDeviceCount (&device_count);
152+ int device_count = 0 ;
153+ // Directly call cudaGetDeviceCount without cudaErrcheck to prevent program exit
154+ cudaError_t error_id = cudaGetDeviceCount (&device_count);
155+ if (error_id == cudaSuccess && device_count > 0 ) {
156+ return true ;
157+ }
158+ return false ;
167159#elif defined(__ROCM)
168- hipGetDeviceCount (&device_count);
169- /* **auto start_time = std::chrono::high_resolution_clock::now();
170- std::cout << "Starting hipGetDeviceCount.." << std::endl;
171- auto end_time = std::chrono::high_resolution_clock::now();
172- auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end_time - start_time);
173- std::cout << "hipGetDeviceCount took " << duration.count() << "seconds" << std::endl;***/
160+ int device_count = 0 ;
161+ hipError_t error_id = hipGetDeviceCount (&device_count);
162+ if (error_id == hipSuccess && device_count > 0 ) {
163+ return true ;
164+ }
165+ return false ;
166+ #else
167+ // If not compiled with GPU support, GPU is not available
168+ return false ;
174169#endif
175- if (device_count <= 0 )
176- {
177- error_message += " Cannot find GPU on this computer!\n " ;
178170}
179- #else // CPU only
180- error_message += " ABACUS is built with CPU support only. Please rebuild with GPU support.\n " ;
171+
172+ std::string get_device_flag (const std::string &device,
173+ const std::string &basis_type) {
174+ // 1. Validate input string
175+ if (device != " cpu" && device != " gpu" && device != " auto" ) {
176+ ModuleBase::WARNING_QUIT (" device" , " Parameter \" device\" can only be set to \" cpu\" , \" gpu\" , or \" auto\" !" );
177+ }
178+
179+ int decision = 0 ; // 0 for CPU, 1 for GPU
180+
181+ #ifdef __MPI
182+ int world_rank = 0 ;
183+ MPI_Comm_rank (MPI_COMM_WORLD, &world_rank);
184+
185+ if (world_rank == 0 ) {
186+ // Rank 0 makes the decision
187+ if (device == " gpu" ) {
188+ if (probe_gpu_availability ()) {
189+ decision = 1 ;
190+ std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
191+ } else {
192+ ModuleBase::WARNING_QUIT (" device" , " Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'." );
193+ }
194+ } else if (device == " auto" ) {
195+ if (probe_gpu_availability ()) {
196+ decision = 1 ;
197+ std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
198+ } else {
199+ decision = 0 ;
200+ std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
201+ std::cout << " To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
202+ }
203+ } else { // device == "cpu"
204+ decision = 0 ;
205+ std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
206+ }
207+ }
208+
209+ // Rank 0 broadcasts the final decision to all other ranks
210+ MPI_Bcast (&decision, 1 , MPI_INT, 0 , MPI_COMM_WORLD);
211+ #else
212+ // Non-MPI case: single process makes the decision
213+ if (device == " gpu" ) {
214+ if (probe_gpu_availability ()) {
215+ decision = 1 ;
216+ std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
217+ } else {
218+ ModuleBase::WARNING_QUIT (" device" , " Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'." );
219+ }
220+ } else if (device == " auto" ) {
221+ if (probe_gpu_availability ()) {
222+ decision = 1 ;
223+ std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
224+ } else {
225+ decision = 0 ;
226+ std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
227+ std::cout << " To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
228+ }
229+ } else { // device == "cpu"
230+ decision = 0 ;
231+ std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
232+ }
181233#endif
182234
183- if (basis_type == " lcao_in_pw" ) {
184- error_message +=
185- " The GPU currently does not support the basis type \" lcao_in_pw\" !" ;
186- }
187- if (error_message.empty ())
188- {
189- return " gpu" ; // possibly automatically set to GPU
190- }
191- else if (device == " gpu" )
192- {
193- ModuleBase::WARNING_QUIT (" device" , error_message);
194- }
195- else { return " cpu" ;
196- }
235+ // 2. Final check for incompatible basis type
236+ if (decision == 1 && basis_type == " lcao_in_pw" ) {
237+ ModuleBase::WARNING_QUIT (" device" , " The GPU currently does not support the basis type \" lcao_in_pw\" !" );
238+ }
239+
240+ // 3. Return the final decision
241+ return (decision == 1 ) ? " gpu" : " cpu" ;
197242}
198243
199244int get_device_kpar (const int & kpar, const int & bndpar)
200245{
201246#if __MPI && (__CUDA || __ROCM)
247+ // This function should only be called when GPU mode is active
248+ // We use probe_gpu_availability to ensure GPU is actually available
249+ if (!probe_gpu_availability ()) {
250+ // If no GPU available, return kpar unchanged
251+ return kpar;
252+ }
253+
202254 int temp_nproc = 0 ;
203255 int new_kpar = kpar;
204256 MPI_Comm_size (MPI_COMM_WORLD, &temp_nproc);
@@ -213,15 +265,15 @@ int get_device_kpar(const int& kpar, const int& bndpar)
213265
214266 int device_num = -1 ;
215267#if defined(__CUDA)
216- cudaGetDeviceCount (&device_num); // get the number of GPU devices of current node
217- cudaSetDevice (node_rank % device_num); // band the CPU processor to the devices
268+ cudaErrcheck ( cudaGetDeviceCount (&device_num) ); // get the number of GPU devices of current node
269+ cudaErrcheck ( cudaSetDevice (node_rank % device_num)) ; // bind the CPU processor to the devices
218270#elif defined(__ROCM)
219- hipGetDeviceCount (&device_num);
220- hipSetDevice (node_rank % device_num);
271+ hipErrcheck ( hipGetDeviceCount (&device_num) );
272+ hipErrcheck ( hipSetDevice (node_rank % device_num) );
221273#endif
222- return new_kpar;
274+ return new_kpar;
223275#endif
224- return kpar;
276+ return kpar;
225277}
226278
227279} // end of namespace information
0 commit comments