Skip to content

Commit 2af2d47

Browse files
committed
Support different CUDA versions in one single cuda_compat.h
1 parent be03870 commit 2af2d47

File tree

3 files changed

+43
-17
lines changed

3 files changed

+43
-17
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* @file cuda_compat.h
3+
* @brief Compatibility layer for CUDA and NVTX headers across different CUDA Toolkit versions.
4+
*
5+
* This header abstracts the differences in NVTX (NVIDIA Tools Extension) header locations
6+
* between CUDA Toolkit versions.
7+
*
8+
* @note Depends on the CUDA_VERSION macro defined in <cuda.h>.
9+
*
10+
*/
11+
12+
#ifndef CUDA_COMPAT_H_
13+
#define CUDA_COMPAT_H_
14+
15+
#include <cuda.h> // defines CUDA_VERSION
16+
17+
// NVTX header for CUDA versions prior to 12.9 vs. 12.9+
18+
// This block ensures the correct NVTX header path is used based on CUDA_VERSION.
19+
// - For CUDA Toolkit < 12.9, the legacy header "nvToolsExt.h" is included.
20+
// - For CUDA Toolkit >= 12.9, the modern header "nvtx3/nvToolsExt.h" is included,
21+
// because NVTX v2 (the legacy header) was removed in CUDA 12.9.
22+
// This allows NVTX profiling APIs (e.g. nvtxRangePush) to be used consistently
23+
// across different CUDA versions.
24+
// See:
25+
// https://docs.nvidia.com/cuda/archive/12.9.0/cuda-toolkit-release-notes/index.html#id4
26+
#if defined(__CUDA) && defined(__USE_NVTX)
27+
#if CUDA_VERSION < 12090
28+
#include "nvToolsExt.h"
29+
#else
30+
#include "nvtx3/nvToolsExt.h"
31+
#endif
32+
#endif
33+
34+
#endif // CUDA_COMPAT_H_

source/source_base/timer.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
#include "source_base/formatter.h"
1616

1717
#if defined(__CUDA) && defined(__USE_NVTX)
18-
#if CUDA_VERSION < 12090
19-
#include "nvToolsExt.h"
20-
#else
21-
#include "nvtx3/nvToolsExt.h"
22-
#endif
18+
#include "source_base/module_device/cuda_compat.h"
2319
#include "source_io/module_parameter/parameter.h"
2420
#endif
2521

source/source_hsolver/kernels/cuda/diag_cusolver.cuh

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
#include <cuda.h>
44
#include <complex>
55

6-
#if CUDA_VERSION < 12090
7-
#include "nvToolsExt.h"
8-
#else
9-
#include "nvtx3/nvToolsExt.h"
10-
#endif
6+
// #include "source_base/module_device/cuda_compat.h"
117

128
#include <cuda_runtime.h>
139
#include <cusolverDn.h>
@@ -39,7 +35,7 @@ class Diag_Cusolver_gvd{
3935
double *d_A = nullptr;
4036
double *d_B = nullptr;
4137
double *d_work = nullptr;
42-
38+
4339
cuDoubleComplex *d_A2 = nullptr;
4440
cuDoubleComplex *d_B2 = nullptr;
4541
cuDoubleComplex *d_work2 = nullptr;
@@ -54,7 +50,7 @@ class Diag_Cusolver_gvd{
5450
// - init_double : initializes the double-type data structures, the GPU API handles, and memory
5551
// - init_complex : initializes the complex-type data structures, the GPU API handles, and memory
5652
// Input Parameters
57-
// N: the dimension of the matrix
53+
// N: the dimension of the matrix
5854
void init_double(int N);
5955
void init_complex(int N);
6056

@@ -70,17 +66,17 @@ public:
7066
// - Dngvd_double : dense double type matrix
7167
// - Dngvd_complex : dense complex type matrix
7268
// Input Parameters
73-
// N: the number of rows of the matrix
74-
// M: the number of cols of the matrix
75-
// A: the hermitian matrix A in A x=lambda B (column major)
76-
// B: the SPD matrix B in A x=lambda B (column major)
69+
// N: the number of rows of the matrix
70+
// M: the number of cols of the matrix
71+
// A: the hermitian matrix A in A x = lambda B x (column major)
72+
// B: the SPD matrix B in A x = lambda B x (column major)
7773
// Output Parameter
7874
// W: generalized eigenvalues
7975
// V: generalized eigenvectors (column major)
8076

8177
void Dngvd_double(int N, int M, double *A, double *B, double *W, double *V);
8278
void Dngvd_complex(int N, int M, std::complex<double> *A, std::complex<double> *B, double *W, std::complex<double> *V);
83-
79+
8480
void Dngvd(int N, int M, double *A, double *B, double *W, double *V)
8581
{
8682
return Dngvd_double(N, M, A, B, W, V);

0 commit comments

Comments
 (0)