Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 66 additions & 49 deletions source/source_base/module_device/device.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include "device.h"

#include "source_base/tool_quit.h"
Expand Down Expand Up @@ -147,58 +146,76 @@ int set_device_by_rank(const MPI_Comm mpi_comm) {

#endif

std::string get_device_flag(const std::string &device,
const std::string &basis_type) {
if (device == "cpu") {
return "cpu"; // no extra checks required
}
std::string error_message;
if (device != "auto" and device != "gpu")
{
error_message += "Parameter \"device\" can only be set to \"cpu\" or \"gpu\"!";
ModuleBase::WARNING_QUIT("device", error_message);
}

// Get available GPU count
int device_count = -1;
#if ((defined __CUDA) || (defined __ROCM))
/**
 * @brief Safely probes for GPU availability without exiting on error.
 *
 * Unlike the cudaErrcheck/hipErrcheck wrappers, the raw runtime calls are
 * used here so that a missing driver or device does not abort the program;
 * the caller decides how to react to an unavailable GPU.
 *
 * @return True if at least one GPU is found and usable, false otherwise
 *         (including builds compiled without GPU support).
 */
bool probe_gpu_availability() {
#if defined(__CUDA)
    int device_count = 0;
    // Directly call cudaGetDeviceCount without cudaErrcheck to prevent program exit
    cudaError_t error_id = cudaGetDeviceCount(&device_count);
    if (error_id == cudaSuccess && device_count > 0) {
        return true;
    }
    return false;
#elif defined(__ROCM)
    int device_count = 0;
    // Same rationale as the CUDA branch: probe without aborting on failure.
    hipError_t error_id = hipGetDeviceCount(&device_count);
    if (error_id == hipSuccess && device_count > 0) {
        return true;
    }
    return false;
#else
    // If not compiled with GPU support, GPU is not available
    return false;
#endif
}
if (device_count <= 0)
{
error_message += "Cannot find GPU on this computer!\n";
}
#else // CPU only
error_message += "ABACUS is built with CPU support only. Please rebuild with GPU support.\n";
#endif

if (basis_type == "lcao_in_pw") {
error_message +=
"The GPU currently does not support the basis type \"lcao_in_pw\"!";
}
if(error_message.empty())
{
return "gpu"; // possibly automatically set to GPU
}
else if (device == "gpu")
{
ModuleBase::WARNING_QUIT("device", error_message);
}
else { return "cpu";
}
/**
 * @brief Decide the effective compute device ("cpu" or "gpu") from the
 *        user's "device" input parameter and the requested basis type.
 *
 * @param device     User request: "cpu", "gpu", or "auto"; anything else aborts.
 * @param basis_type Basis set name; "lcao_in_pw" is incompatible with GPU runs.
 * @return "gpu" when a GPU was requested (or auto-detected) and is usable,
 *         otherwise "cpu".  Aborts via WARNING_QUIT on invalid input, on
 *         "gpu" with no usable GPU, or on a GPU/basis-type conflict.
 */
std::string get_device_flag(const std::string &device,
                            const std::string &basis_type) {
    // Guard clause: only the three accepted keywords may pass.
    if (!(device == "cpu" || device == "gpu" || device == "auto")) {
        ModuleBase::WARNING_QUIT("device", "Parameter \"device\" can only be set to \"cpu\", \"gpu\", or \"auto\"!");
    }

    // NOTE: this function runs on rank 0 only, during input parsing; the
    // decision reaches the other ranks through the standard bcast mechanism.
    // DO NOT call MPI_Bcast here -- the other ranks never enter this path.

    // Probe the hardware only when a GPU could actually be used; "cpu"
    // short-circuits without touching the GPU runtime.
    const bool gpu_usable = (device != "cpu") && probe_gpu_availability();

    std::string result = "cpu";
    if (gpu_usable) {
        result = "gpu";
    } else if (device == "gpu") {
        // An explicit GPU request that cannot be satisfied is fatal.
        ModuleBase::WARNING_QUIT("device", "Device is set to 'gpu', but no available GPU was found. Please check your hardware/drivers or set 'device=cpu'.");
    }
    // ("auto" without a GPU, or "cpu", both fall through with result == "cpu".)

    // A GPU run is still impossible with this basis type.
    if (result == "gpu" && basis_type == "lcao_in_pw") {
        ModuleBase::WARNING_QUIT("device", "The GPU currently does not support the basis type \"lcao_in_pw\"!");
    }

    return result;
}

int get_device_kpar(const int& kpar, const int& bndpar)
{
#if __MPI && (__CUDA || __ROCM)
// This function should only be called when device mode is GPU
// The device decision has already been made by get_device_flag()
int temp_nproc = 0;
int new_kpar = kpar;
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
Expand All @@ -213,15 +230,15 @@ int get_device_kpar(const int& kpar, const int& bndpar)

int device_num = -1;
#if defined(__CUDA)
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
cudaErrcheck(cudaGetDeviceCount(&device_num)); // get the number of GPU devices of current node
cudaErrcheck(cudaSetDevice(node_rank % device_num)); // bind the CPU processor to the devices
#elif defined(__ROCM)
hipGetDeviceCount(&device_num);
hipSetDevice(node_rank % device_num);
hipErrcheck(hipGetDeviceCount(&device_num));
hipErrcheck(hipSetDevice(node_rank % device_num));
#endif
return new_kpar;
return new_kpar;
#endif
return kpar;
return kpar;
}

} // end of namespace information
Expand Down
6 changes: 6 additions & 0 deletions source/source_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ void output_device_info(std::ostream& output);
*/
int get_device_kpar(const int& kpar, const int& bndpar);

/**
* @brief Safely probes for GPU availability without exiting on error.
* @return True if at least one GPU is found and usable, false otherwise.
*/
bool probe_gpu_availability();

/**
* @brief Get the device flag object
* for source_io PARAM.inp.device
Expand Down
16 changes: 14 additions & 2 deletions source/source_base/module_device/output_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ void output_device_info(std::ostream &output)
int local_rank = get_node_rank_with_mpi_shared(MPI_COMM_WORLD);

// Get local hardware info
int local_gpu_count = local_rank == 0 ? get_device_num("gpu") : 0;
int local_gpu_count = 0;
#if defined(__CUDA) || defined(__ROCM)
if(PARAM.inp.device == "gpu" && local_rank == 0)
{
local_gpu_count = get_device_num("gpu");
}
#endif
int local_cpu_sockets = local_rank == 0 ? get_device_num("cpu") : 0;

// Prepare vectors to gather data from all ranks
Expand All @@ -133,7 +139,13 @@ void output_device_info(std::ostream &output)

// Get device model names (from rank 0 node)
std::string cpu_name = get_device_name("cpu");
std::string gpu_name = get_device_name("gpu");
std::string gpu_name;
#if defined(__CUDA) || defined(__ROCM)
if(PARAM.inp.device == "gpu" && total_gpus > 0)
{
gpu_name = get_device_name("gpu");
}
#endif

// Output all collected information
output << " RUNNING WITH DEVICE : " << "CPU" << " / "
Expand Down
Loading