diff --git a/source/source_base/module_device/device.cpp b/source/source_base/module_device/device.cpp
index 96deae8baf..bddfbaa62e 100644
--- a/source/source_base/module_device/device.cpp
+++ b/source/source_base/module_device/device.cpp
@@ -1,4 +1,3 @@
-
 #include "device.h"
 #include "source_base/tool_quit.h"

@@ -147,58 +146,76 @@ int set_device_by_rank(const MPI_Comm mpi_comm) {
 #endif

-std::string get_device_flag(const std::string &device,
-                            const std::string &basis_type) {
-if (device == "cpu") {
-    return "cpu"; // no extra checks required
-}
-std::string error_message;
-if (device != "auto" and device != "gpu")
-{
-    error_message += "Parameter \"device\" can only be set to \"cpu\" or \"gpu\"!";
-    ModuleBase::WARNING_QUIT("device", error_message);
-}
-
-// Get available GPU count
-int device_count = -1;
-#if ((defined __CUDA) || (defined __ROCM))
+bool probe_gpu_availability() {
 #if defined(__CUDA)
-cudaGetDeviceCount(&device_count);
+    int device_count = 0;
+    // Call cudaGetDeviceCount directly, without cudaErrcheck, so a failure does not exit the program
+    cudaError_t error_id = cudaGetDeviceCount(&device_count);
+    if (error_id == cudaSuccess && device_count > 0) {
+        return true;
+    }
+    return false;
 #elif defined(__ROCM)
-hipGetDeviceCount(&device_count);
-/***auto start_time = std::chrono::high_resolution_clock::now();
-std::cout << "Starting hipGetDeviceCount.." << std::endl;
-auto end_time = std::chrono::high_resolution_clock::now();
-auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end_time - start_time);
-std::cout << "hipGetDeviceCount took " << duration.count() << "seconds" << std::endl;***/
+    int device_count = 0;
+    hipError_t error_id = hipGetDeviceCount(&device_count);
+    if (error_id == hipSuccess && device_count > 0) {
+        return true;
+    }
+    return false;
+#else
+    // If not compiled with GPU support, GPU is not available
+    return false;
 #endif
-if (device_count <= 0)
-{
-    error_message += "Cannot find GPU on this computer!\n";
 }
-#else // CPU only
-error_message += "ABACUS is built with CPU support only. Please rebuild with GPU support.\n";
-#endif
-if (basis_type == "lcao_in_pw") {
-    error_message +=
-        "The GPU currently does not support the basis type \"lcao_in_pw\"!";
-}
-if(error_message.empty())
-{
-    return "gpu"; // possibly automatically set to GPU
-}
-else if (device == "gpu")
-{
-    ModuleBase::WARNING_QUIT("device", error_message);
-}
-else { return "cpu";
-}
+
+std::string get_device_flag(const std::string &device,
+                            const std::string &basis_type) {
+    // 1. Validate input string
+    if (device != "cpu" && device != "gpu" && device != "auto") {
+        ModuleBase::WARNING_QUIT("device", "Parameter \"device\" can only be set to \"cpu\", \"gpu\", or \"auto\"!");
+    }
+
+    // NOTE: This function is called only on rank 0 during input parsing.
+    // The result will be broadcast to other ranks via the standard bcast mechanism.
+    // DO NOT use MPI_Bcast here as other ranks are not in this code path.
+
+    std::string result = "cpu";
+
+    if (device == "gpu") {
+        if (probe_gpu_availability()) {
+            result = "gpu";
+            // std::cout << " INFO: 'device=gpu' specified. GPU will be used." << std::endl;
+        } else {
+            ModuleBase::WARNING_QUIT("device", "Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'.");
+        }
+    } else if (device == "auto") {
+        if (probe_gpu_availability()) {
+            result = "gpu";
+            // std::cout << " INFO: 'device=auto' specified. GPU detected and will be used." << std::endl;
+        } else {
+            result = "cpu";
+            // std::cout << " WARNING: 'device=auto' specified, but no GPU was found. Falling back to CPU." << std::endl;
+            // std::cout << "          To suppress this warning, please explicitly set 'device=cpu' in your input." << std::endl;
+        }
+    } else { // device == "cpu"
+        result = "cpu";
+        // std::cout << " INFO: 'device=cpu' specified. CPU will be used." << std::endl;
+    }
+
+    // 2. Final check for incompatible basis type
+    if (result == "gpu" && basis_type == "lcao_in_pw") {
+        ModuleBase::WARNING_QUIT("device", "The GPU currently does not support the basis type \"lcao_in_pw\"!");
+    }
+
+    // 3. Return the final decision
+    return result;
 }

 int get_device_kpar(const int& kpar, const int& bndpar)
 {
 #if __MPI && (__CUDA || __ROCM)
+    // This function should only be called when the device mode is GPU;
+    // the device decision has already been made by get_device_flag().
     int temp_nproc = 0;
     int new_kpar = kpar;
     MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
@@ -213,15 +230,15 @@ int get_device_kpar(const int& kpar, const int& bndpar)

     int device_num = -1;
 #if defined(__CUDA)
-    cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
-    cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
+    cudaErrcheck(cudaGetDeviceCount(&device_num)); // get the number of GPU devices on the current node
+    cudaErrcheck(cudaSetDevice(node_rank % device_num)); // bind this process to one of the node's GPUs
 #elif defined(__ROCM)
-    hipGetDeviceCount(&device_num);
-    hipSetDevice(node_rank % device_num);
+    hipErrcheck(hipGetDeviceCount(&device_num));
+    hipErrcheck(hipSetDevice(node_rank % device_num));
 #endif
-    return new_kpar;
+    return new_kpar;
 #endif
-    return kpar;
+    return kpar;
 }

 } // end of namespace information
diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h
index 7b8dd0c6ae..f6dfd3f207 100644
--- a/source/source_base/module_device/device.h
+++ b/source/source_base/module_device/device.h
@@ -44,6 +44,12 @@ void output_device_info(std::ostream& output);
  */
 int get_device_kpar(const int& kpar, const int& bndpar);

+/**
+ * @brief Safely probes for GPU availability without exiting on error.
+ * @return True if at least one GPU is found and usable, false otherwise.
+ */
+bool probe_gpu_availability();
+
 /**
  * @brief Get the device flag object
  * for source_io PARAM.inp.device
diff --git a/source/source_base/module_device/output_device.cpp b/source/source_base/module_device/output_device.cpp
index 1d0f018814..41b4c6d082 100644
--- a/source/source_base/module_device/output_device.cpp
+++ b/source/source_base/module_device/output_device.cpp
@@ -115,7 +115,13 @@ void output_device_info(std::ostream &output)
     int local_rank = get_node_rank_with_mpi_shared(MPI_COMM_WORLD);

     // Get local hardware info
-    int local_gpu_count = local_rank == 0 ? get_device_num("gpu") : 0;
+    int local_gpu_count = 0;
+    #if defined(__CUDA) || defined(__ROCM)
+    if(PARAM.inp.device == "gpu" && local_rank == 0)
+    {
+        local_gpu_count = get_device_num("gpu");
+    }
+    #endif
     int local_cpu_sockets = local_rank == 0 ? get_device_num("cpu") : 0;
get_device_num("cpu") : 0; // Prepare vectors to gather data from all ranks @@ -133,7 +139,13 @@ void output_device_info(std::ostream &output) // Get device model names (from rank 0 node) std::string cpu_name = get_device_name("cpu"); - std::string gpu_name = get_device_name("gpu"); + std::string gpu_name; + #if defined(__CUDA) || defined(__ROCM) + if(PARAM.inp.device == "gpu" && total_gpus > 0) + { + gpu_name = get_device_name("gpu"); + } + #endif // Output all collected information output << " RUNNING WITH DEVICE : " << "CPU" << " / "