@@ -1,5 +1,9 @@
 #include <assert.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <stdio.h>
+#include <torch/script.h>
+#include <torch/torch.h>
 
 #define CEIL_DIV(x, y) ((x + y - 1) / y)
 
@@ -57,14 +61,17 @@ __device__ int2 compute_warp_start_stop(int block_idx, int warp_idx,
 // decay storage, h_storage:
 //   each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block
 //   reduction
-__global__ void reduction_kernel(float *decays, float *impulses,
-                                 float *initial_state, float *_decay_storage,
-                                 float *_h_storage, int n_dims, int n_steps) {
+template <typename scalar_t>
+__global__ void reduction_kernel(const scalar_t *decays,
+                                 const scalar_t *impulses,
+                                 const scalar_t *initial_state,
+                                 scalar_t *_decay_storage, scalar_t *_h_storage,
+                                 int n_dims, int n_steps) {
   int warp = threadIdx.x / 32;
   int lane = threadIdx.x % 32;
 
-  float *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims];
-  float *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
+  scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims];
+  scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
 
   int2 start_stop =
       compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
@@ -78,8 +85,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
    * (feature_idx, warp, block).
    */
   for (int i = lane; i < n_dims; i += 32) {
-    float cum_decay = 1.0;
-    float h = 0.0;
+    scalar_t cum_decay = static_cast<scalar_t>(1.0);
+    scalar_t h = static_cast<scalar_t>(0.0);
     if (blockIdx.x == 0 && warp == 0 && initial_state != NULL) {
       h = initial_state[i];
     }
@@ -106,8 +113,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
   // TODO: parallel reduction (or scan). Need to worry about changing the warp
   // reduction values (as I use them again later)
   for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
-    float cum_decay = 1.0;
-    float h = 0.0;
+    scalar_t cum_decay = static_cast<scalar_t>(1.0);
+    scalar_t h = static_cast<scalar_t>(0.0);
     for (int t = 0; t < 32; t++) {
       cum_decay *= decay_storage[i + t * n_dims];
       h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims];
@@ -117,7 +124,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
   }
 }
 
-__global__ void block_scan_kernel(float *decay_storage, float *h_storage,
+template <typename scalar_t>
+__global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage,
                                   int n_dims, int n_blocks) {
   /*
    * Scan over blocks.
@@ -142,9 +150,11 @@ __global__ void block_scan_kernel(float *decay_storage, float *h_storage,
   }
 }
 
-__global__ void warp_scan_kernel(float *decays, float *impulses,
-                                 float *initial_state, float *out,
-                                 float *decay_storage, float *h_storage,
+template <typename scalar_t>
+__global__ void warp_scan_kernel(const scalar_t *decays,
+                                 const scalar_t *impulses,
+                                 const scalar_t *initial_state, scalar_t *out,
+                                 scalar_t *decay_storage, scalar_t *h_storage,
                                  int n_dims, int n_steps) {
   int warp = threadIdx.x / 32;
   int lane = threadIdx.x % 32;
@@ -196,7 +206,7 @@ __global__ void warp_scan_kernel(float *decays, float *impulses,
    * writes to output for indices warp_start up to warp_stop.
    */
   for (int i = lane; i < n_dims; i += 32) {
-    float h = 0.0;
+    scalar_t h = static_cast<scalar_t>(0.0);
     if (blockIdx.x == 0 && warp == 0) {
       if (initial_state != NULL) {
         h = initial_state[i];
@@ -212,34 +222,17 @@ __global__ void warp_scan_kernel(float *decays, float *impulses,
   }
 }
 
-__global__ void serial_linear_recurrence(float *decays, float *impulses,
-                                         float *initial_state, float *out,
-                                         int n_dims, int n_steps) {
-  // computes h_t = lambda_t h{t-1} + x_t
-
-  for (int dim_idx = threadIdx.x + blockIdx.x * blockDim.x; dim_idx < n_dims;
-       dim_idx += blockDim.x * gridDim.x) {
-    float val = initial_state[dim_idx];
-
-    for (int step = 0; step < n_steps; step++) {
-      int idx = dim_idx + step * n_dims;
-      val = decays[idx] * val + impulses[idx];
-      out[idx] = val;
-    }
-  }
-}
-
-extern "C" {
 /*
  * This is the main method for the prefix sum kernels.
  * decays, impulses, out:
  *   each a n_dims x n_steps column major matrix located on GPU
  * initial_state:
  *   array of size n_dims located on GPU
  */
-void compute_linear_recurrence(float *decays, float *impulses,
-                               float *initial_state, float *out, int n_dims,
-                               int n_steps) {
+template <typename scalar_t>
+void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
+                               const scalar_t *initial_state, scalar_t *out,
+                               int n_dims, int n_steps) {
   // TODO: query
   int n_SMs = 15;
   int n_blocks_per_sm = 2;
@@ -251,10 +244,10 @@ void compute_linear_recurrence(float *decays, float *impulses,
   // TODO: make user pass in working memory? This allows integration
   // with CNMeM (used by Theano)
   int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(float);
-  float *d_reduction_mem;
+  scalar_t *d_reduction_mem;
   gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz));
-  float *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims];
-  float *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims];
+  scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims];
+  scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims];
 
   // TODO: run kernels on non-default stream?
   reduction_kernel<<<n_blocks, 1024>>>(decays, impulses, initial_state,
@@ -271,53 +264,31 @@ void compute_linear_recurrence(float *decays, float *impulses,
   gpuErrChk(cudaFree(d_reduction_mem));
 }
 
-void compute_serial_linear_recurrence(float *decays, float *impulses,
-                                      float *initial_state, float *out,
-                                      int n_dims, int n_steps) {
-  // TODO: query
-  int n_SMs = 15;
-  int n_blocks_per_sm = 2;
-
-  int n_blocks = n_SMs * n_blocks_per_sm;
-  serial_linear_recurrence<<<n_blocks, 1024>>>(
-      decays, impulses, initial_state, out, n_dims, n_steps);
-}
+at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
+                             const at::Tensor &initials) {
+  TORCH_CHECK(input.is_floating_point() || input.is_complex(),
+              "Input must be floating point or complex");
+  TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
+              "Initials must have the same scalar type as input");
+  TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
+              "Weights must have the same scalar type as input");
+
+  auto input_contiguous = input.transpose(0, 1).contiguous();
+  auto weights_contiguous = weights.transpose(0, 1).contiguous();
+  auto output = at::empty_like(input_contiguous);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      input.scalar_type(), "compute_linear_recurrence", [&] {
+        compute_linear_recurrence<scalar_t>(
+            weights_contiguous.const_data_ptr<scalar_t>(),
+            input_contiguous.const_data_ptr<scalar_t>(),
+            initials.const_data_ptr<scalar_t>(),
+            output.mutable_data_ptr<scalar_t>(), input.size(1),
+            input.size(0));
+      });
+  return output.transpose(0, 1).contiguous();
 }
 
-void test() {
-  int n_dims = 100;
-  int n_steps = 1000000;
-  int n_elements = n_dims * n_steps;
-
-  float *decays = (float *)calloc(n_elements, sizeof(float));
-  for (int i = 0; i < n_elements; i++) {
-    decays[i] = .999;
-  }
-  float *d_decays;
-  gpuErrChk(cudaMalloc(&d_decays, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemcpy(d_decays, decays, n_elements * sizeof(float),
-                       cudaMemcpyHostToDevice));
-
-  float *impulses = (float *)calloc(n_elements, sizeof(float));
-  for (int i = 0; i < n_dims; i++) {
-    impulses[i + 0 * n_dims] = 2.0;
-  }
-  float *d_impulses;
-  gpuErrChk(cudaMalloc(&d_impulses, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemcpy(d_impulses, impulses, n_elements * sizeof(float),
-                       cudaMemcpyHostToDevice));
-
-  float *out = (float *)calloc(n_elements, sizeof(float));
-  float *d_out;
-  gpuErrChk(cudaMalloc(&d_out, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemset(d_out, 0, n_elements * sizeof(float)));
-
-  compute_linear_recurrence(d_decays, d_impulses, NULL, d_out, n_dims,
-                            n_steps);
-  gpuErrChk(cudaMemcpy(out, d_out, n_elements * sizeof(float),
-                       cudaMemcpyDeviceToHost));
-
-  gpuErrChk(cudaFree(d_decays));
-  gpuErrChk(cudaFree(d_impulses));
-  gpuErrChk(cudaFree(d_out));
-}
+TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
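
For reference, every kernel in this file parallelizes the recurrence that the deleted serial_linear_recurrence computed one step at a time: h_t = decay_t * h_{t-1} + impulse_t, applied independently along each of the n_dims channels in the column-major n_dims x n_steps layout documented above compute_linear_recurrence. The host-side sketch below is not part of the commit (reference_linear_recurrence is a hypothetical name); it restates that loop and can serve as a CPU oracle when checking the templated kernels.

// Hypothetical CPU reference, assuming the column-major layout described
// above: element (dim, step) lives at offset dim + step * n_dims.
template <typename scalar_t>
void reference_linear_recurrence(const scalar_t *decays,
                                 const scalar_t *impulses,
                                 const scalar_t *initial_state, scalar_t *out,
                                 int n_dims, int n_steps) {
  for (int dim = 0; dim < n_dims; dim++) {
    // Seed with the initial state when one is given, otherwise with zero,
    // mirroring the NULL check in the kernels above.
    scalar_t h =
        initial_state ? initial_state[dim] : static_cast<scalar_t>(0.0);
    for (int step = 0; step < n_steps; step++) {
      int idx = dim + step * n_dims;
      h = decays[idx] * h + impulses[idx];  // h_t = decay_t * h_{t-1} + x_t
      out[idx] = h;
    }
  }
}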
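Note that TORCH_LIBRARY_IMPL only attaches the CUDA kernel to an existing operator schema; the schema itself has to be declared exactly once elsewhere in the extension, which this file does not show. A minimal sketch of what that declaration could look like, assuming the three-tensor signature used by scan_cuda_wrapper (the exact schema string is an assumption, not taken from this commit):

// Sketch only -- lives in a separate translation unit of the same extension.
#include <torch/library.h>

TORCH_LIBRARY(torchlpc, m) {
  // Same argument order as scan_cuda_wrapper; dispatch then routes CUDA
  // tensors to the implementation registered in this file.
  m.def("scan(Tensor input, Tensor weights, Tensor initials) -> Tensor");
}

With both pieces loaded, the operator is reachable as torch.ops.torchlpc.scan(input, weights, initials) from Python, or from C++ through c10::Dispatcher::singleton().findSchemaOrThrow("torchlpc::scan", "").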