@@ -74,7 +74,7 @@ __global__ void reduction_kernel(const scalar_t *decays,
   scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];

   int2 start_stop =
-      compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
+      compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
   int warp_start = start_stop.x;
   int warp_stop = start_stop.y;

@@ -84,22 +84,22 @@ __global__ void reduction_kernel(const scalar_t *decays,
    * from warp_start to warp_stop (including initial state) at index
    * (feature_idx, warp, block).
    */
-  for (int i = lane; i < n_dims; i += 32) {
+  for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
     scalar_t cum_decay = static_cast<scalar_t>(1.0);
     scalar_t h = static_cast<scalar_t>(0.0);
-    if (blockIdx.x == 0 && warp == 0 && initial_state != NULL) {
+    if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) {
       h = initial_state[i];
     }

     for (int t = warp_start; t < warp_stop; t++) {
-      cum_decay *= decays[i + t * n_dims];
-      h = decays[i + t * n_dims] * h + impulses[i + t * n_dims];
+      cum_decay *= decays[i * n_steps + t];
+      h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
     }

     // TODO: store into shared memory, work in shared memory sized blocks
     // store into global memory
-    decay_storage[i + warp * n_dims] = cum_decay;
-    h_storage[i + warp * n_dims] = h;
+    decay_storage[i + lane * n_dims] = cum_decay;
+    h_storage[i + lane * n_dims] = h;
   }

   __syncthreads();
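For orientation, a minimal sketch of where a partial result lands after this change. The definitions of `warp` and `lane` sit earlier in the file and are not part of this diff, so reading them as the conventional warp-within-block and lane-within-warp indices is an assumption; under that reading, the 32 lanes of a warp now each reduce one contiguous chunk of timesteps for a single feature dimension, while the warps stride over the feature dimensions.

// Sketch only, not part of the commit: offset of the partial result for
// feature i produced by timestep chunk `chunk` of block `block`.  Each block
// owns a 33 * n_dims slab of decay_storage / h_storage; after this change the
// chunk index within the slab is the lane index rather than the warp index.
__device__ __forceinline__ int storage_offset(int block, int chunk, int i,
                                              int n_dims) {
  return block * 33 * n_dims + chunk * n_dims + i;
}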
@@ -112,7 +112,7 @@ __global__ void reduction_kernel(const scalar_t *decays,
    */
   // TODO: parallel reduction (or scan). Need to worry about changing the warp
   // reduction values (as I use them again later)
-  for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
+  for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
     scalar_t cum_decay = static_cast<scalar_t>(1.0);
     scalar_t h = static_cast<scalar_t>(0.0);
     for (int t = 0; t < 32; t++) {
@@ -176,7 +176,7 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
    * condition) up to and including the indexed warp and block.
    */
   // TODO: parallel scan
-  for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
+  for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
     for (int t = 0; t < 32; t++) {
       if (t == 0 && blockIdx.x == 0) {
         // the reduction over warp 0 (including initial condition) is
@@ -195,7 +195,7 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
   __syncthreads();

   int2 start_stop =
-      compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
+      compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
   int warp_start = start_stop.x;
   int warp_stop = start_stop.y;

@@ -205,19 +205,19 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
    * state (either from the "initial_state" or the storage arrays) and then
    * writes to output for indices warp_start up to warp_stop.
    */
-  for (int i = lane; i < n_dims; i += 32) {
+  for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
     scalar_t h = static_cast<scalar_t>(0.0);
-    if (blockIdx.x == 0 && warp == 0) {
+    if (blockIdx.x == 0 && lane == 0) {
       if (initial_state != NULL) {
         h = initial_state[i];
       }
     } else {
-      h = h_storage[i + (warp - 1) * n_dims + blockIdx.x * 33 * n_dims];
+      h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims];
     }

     for (int t = warp_start; t < warp_stop; t++) {
-      h = decays[i + t * n_dims] * h + impulses[i + t * n_dims];
-      out[i + t * n_dims] = h;
+      h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
+      out[i * n_steps + t] = h;
     }
   }
 }
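The index rewrite from `i + t * n_dims` to `i * n_steps + t` means both kernels now read `decays`/`impulses` and write `out` in (n_dims, n_steps) row-major order, matching the un-transposed tensors the wrapper below now passes in. A minimal sketch of the assumed addressing (illustrative helper, not in the commit):

// Element (i, t) of a contiguous (n_dims, n_steps) array, as the kernels now
// assume; the previous layout was (n_steps, n_dims), addressed as
// i + t * n_dims.
__host__ __device__ __forceinline__ int elem_index(int i, int t, int n_steps) {
  return i * n_steps + t;
}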
@@ -233,13 +233,10 @@ template <typename scalar_t>
 void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
                                const scalar_t *initial_state, scalar_t *out,
                                int n_dims, int n_steps) {
-  // TODO: query
-  int n_SMs = 15;
-  int n_blocks_per_sm = 2;
-
   // we want at least 32 elements per block, but no reason to run
   // with more than the maximum number of concurrent blocks
-  int n_blocks = min(CEIL_DIV(n_steps, 32), n_SMs * n_blocks_per_sm);
+  // NOTE: 128 is decided empirically.
+  int n_blocks = min(CEIL_DIV(n_steps, 32), 128);

   // TODO: make user pass in working memory? This allows integration
   // with CNMeM (used by Theano)
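The removed `n_SMs = 15` / `n_blocks_per_sm = 2` constants and their "TODO: query" note are replaced by an empirical cap of 128 blocks. If that cap ever needs to track the device instead, one possible sketch is to query it at runtime; the block size of 1024 and the use of `reduction_kernel<scalar_t>` as the occupancy reference are assumptions, not something this diff fixes:

// Sketch: derive the block cap from the device rather than hard-coding 128.
int device = 0, n_SMs = 0, n_blocks_per_sm = 0;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&n_SMs, cudaDevAttrMultiProcessorCount, device);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &n_blocks_per_sm, reduction_kernel<scalar_t>, /*blockSize=*/1024,
    /*dynamicSMemSize=*/0);
int n_blocks = min(CEIL_DIV(n_steps, 32), n_SMs * n_blocks_per_sm);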
@@ -273,8 +270,8 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
   TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
               "Weights must have the same scalar type as input");

-  auto input_contiguous = input.transpose(0, 1).contiguous();
-  auto weights_contiguous = weights.transpose(0, 1).contiguous();
+  auto input_contiguous = input.contiguous();
+  auto weights_contiguous = weights.contiguous();
   auto output = at::empty_like(input_contiguous);

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -285,10 +282,10 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
         weights_contiguous.const_data_ptr<scalar_t>(),
         input_contiguous.const_data_ptr<scalar_t>(),
         initials.const_data_ptr<scalar_t>(),
-        output.mutable_data_ptr<scalar_t>(), input_contiguous.size(1),
-        input_contiguous.size(0));
+        output.mutable_data_ptr<scalar_t>(), input_contiguous.size(0),
+        input_contiguous.size(1));
   });
-  return output.transpose(0, 1).contiguous();
+  return output.contiguous();
 }

 TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
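With the transposes gone, `scan_cuda_wrapper` forwards `size(0)` as `n_dims` and `size(1)` as `n_steps`, so the kernels consume the tensors in their native (n_dims, n_steps) layout. A minimal guard reflecting that convention (hypothetical helper, not part of this commit):

// Hypothetical check for the layout the kernels now expect: 2-D tensors shaped
// (n_dims, n_steps), with weights matching input.
static void check_scan_layout(const at::Tensor &input,
                              const at::Tensor &weights) {
  TORCH_CHECK(input.dim() == 2, "input must be 2-D (n_dims, n_steps)");
  TORCH_CHECK(weights.sizes() == input.sizes(),
              "weights must have the same shape as input");
}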