@@ -74,10 +74,10 @@ template<typename T>
7474static __global__ void cumsum_kernel (
7575 const T * src, T * dst,
7676 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
77- const int64_t nb00 , const int64_t nb01 , const int64_t nb02 , const int64_t nb03 ,
78- const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3 ) {
77+ const int64_t s00 , const int64_t s01 , const int64_t s02 , const int64_t s03 ,
78+ const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3 ) {
7979
80- GGML_UNUSED_VARS (nb00, nb0 );
80+ GGML_UNUSED_VARS (s00, s0 );
8181
8282 const int tid = threadIdx .x ;
8383 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
@@ -104,8 +104,8 @@ static __global__ void cumsum_kernel(
104104 return ;
105105 }
106106
107- const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03 ;
108- T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3 ;
107+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03 ;
108+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3 ;
109109
110110 for (int64_t start = 0 ; start < ne00; start += blockDim .x ) {
111111 int64_t idx = start + tid;
@@ -153,22 +153,23 @@ template<typename T>
153153static void cumsum_cuda (
154154 const T * src, T * dst,
155155 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
156- const int64_t nb00 , const int64_t nb01 , const int64_t nb02 , const int64_t nb03 ,
157- const int64_t nb0 , const int64_t nb1 , const int64_t nb2 , const int64_t nb3 ,
156+ const int64_t s00 , const int64_t s01 , const int64_t s02 , const int64_t s03 ,
157+ const int64_t s0 , const int64_t s1 , const int64_t s2 , const int64_t s3 ,
158158 cudaStream_t stream) {
159159
160160 const size_t type_size = sizeof (T);
161161 bool use_cub = false ;
162162#ifdef GGML_CUDA_USE_CUB
163163 // Check if we can use CUB (data must be contiguous along innermost dimension)
164- const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
164+ const bool is_contiguous = (s00 == type_size) && (s0 == type_size);
165165
166166 if (is_contiguous) {
167167 use_cub = true ;
168168 }
169169#endif // GGML_CUDA_USE_CUB
170170 dim3 grid_dims (ne01, ne02, ne03);
171- const int warp_size = ggml_cuda_get_physical_warp_size_host ();
171+ const auto &info = ggml_cuda_info ().devices [ggml_cuda_get_device ()];
172+ const int warp_size = info.warp_size ;
172173 const int num_warps = (ne00 + warp_size - 1 ) / warp_size;
173174 int block_size = num_warps * warp_size;
174175 block_size = std::min (block_size, CUDA_CUMSUM_BLOCK_SIZE);
@@ -180,15 +181,15 @@ static void cumsum_cuda(
180181 cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0 , stream>>> (
181182 src, dst,
182183 ne00, ne01, ne02, ne03,
183- nb01 / type_size, nb02 / type_size, nb03 / type_size,
184- nb1 / type_size, nb2 / type_size, nb3 / type_size
184+ s01 / type_size, s02 / type_size, s03 / type_size,
185+ s1 / type_size, s2 / type_size, s3 / type_size
185186 );
186187 } else {
187188 cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>> (
188189 src, dst,
189190 ne00, ne01, ne02, ne03,
190- nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
191- nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
191+ s00 / type_size, s01 / type_size, s02 / type_size, s03 / type_size,
192+ s0 / type_size, s1 / type_size, s2 / type_size, s3 / type_size
192193 );
193194 }
194195}
0 commit comments