
Commit a98aae9

refactor: use channel-first format, swap the role of lane and warp to run faster
1 parent: bf9acfe
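In short: the decay/impulse tensors move from a time-major layout, where element (t, i) lives at i + t * n_dims, to a channel-first layout, where element (i, t) lives at i * n_steps + t, and the thread mapping is swapped to match, so that warps now stride over channels while the 32 lanes of each warp each own one contiguous chunk of the time axis (previously lanes strode over channels and warps owned the time chunks). The ATen wrapper consequently drops its transpose(0, 1) round-trips. Below is a minimal sketch of the new mapping, assuming this reading of the diff; it omits the cross-chunk stitching that the real reduction and scan kernels perform, and recurrence_sketch is an illustrative name, not code from this commit.

#include <cuda_runtime.h>

#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

// Channel-first layout: element (channel i, step t) lives at i * n_steps + t.
__global__ void recurrence_sketch(const float *decays, const float *impulses,
                                  float *out, int n_dims, int n_steps) {
  int warp = threadIdx.x / 32; // warp index: which channel this thread works on
  int lane = threadIdx.x % 32; // lane index: which time chunk this thread scans

  // each lane owns one contiguous chunk of the time axis
  int chunk = CEIL_DIV(n_steps, 32);
  int start = lane * chunk;
  int stop = min(start + chunk, n_steps);

  // warps stride over channels; within a warp, the 32 lanes scan their
  // chunks independently (the real kernels then stitch the chunks together)
  for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
    float h = 0.0f;
    for (int t = start; t < stop; t++) {
      h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
      out[i * n_steps + t] = h; // per-chunk partial result only
    }
  }
}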

File tree

1 file changed: +22 -25 lines


torchlpc/csrc/cuda/linear_recurrence.cu

Lines changed: 22 additions & 25 deletions
@@ -74,7 +74,7 @@ __global__ void reduction_kernel(const scalar_t *decays,
   scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
 
   int2 start_stop =
-      compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
+      compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
   int warp_start = start_stop.x;
   int warp_stop = start_stop.y;
 
@@ -84,22 +84,22 @@ __global__ void reduction_kernel(const scalar_t *decays,
    * from warp_start to warp_stop (including initial state) at index
    * (feature_idx, warp, block).
    */
-  for (int i = lane; i < n_dims; i += 32) {
+  for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
     scalar_t cum_decay = static_cast<scalar_t>(1.0);
     scalar_t h = static_cast<scalar_t>(0.0);
-    if (blockIdx.x == 0 && warp == 0 && initial_state != NULL) {
+    if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) {
       h = initial_state[i];
     }
 
     for (int t = warp_start; t < warp_stop; t++) {
-      cum_decay *= decays[i + t * n_dims];
-      h = decays[i + t * n_dims] * h + impulses[i + t * n_dims];
+      cum_decay *= decays[i * n_steps + t];
+      h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
     }
 
     // TODO: store into shared memory, work in shared memory sized blocks
     // store into global memory
-    decay_storage[i + warp * n_dims] = cum_decay;
-    h_storage[i + warp * n_dims] = h;
+    decay_storage[i + lane * n_dims] = cum_decay;
+    h_storage[i + lane * n_dims] = h;
   }
 
   __syncthreads();
@@ -112,7 +112,7 @@ __global__ void reduction_kernel(const scalar_t *decays,
    */
   // TODO: parallel reduction (or scan). Need to worry about changing the warp
   // reduction values (as I use them again later)
-  for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
+  for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
     scalar_t cum_decay = static_cast<scalar_t>(1.0);
     scalar_t h = static_cast<scalar_t>(0.0);
     for (int t = 0; t < 32; t++) {
@@ -176,7 +176,7 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
    * condition) up to and including the indexed warp and block.
    */
   // TODO: parallel scan
-  for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
+  for (int i = threadIdx.x; i < n_dims; i += blockDim.x) {
     for (int t = 0; t < 32; t++) {
       if (t == 0 && blockIdx.x == 0) {
         // the reduction over warp 0 (including initial condition) is
@@ -195,7 +195,7 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
   __syncthreads();
 
   int2 start_stop =
-      compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
+      compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps);
   int warp_start = start_stop.x;
   int warp_stop = start_stop.y;
 
@@ -205,19 +205,19 @@ __global__ void warp_scan_kernel(const scalar_t *decays,
    * state (either from the "initial_state" or the storage arrays) and then
    * writes to output for indices warp_start up to warp_stop.
    */
-  for (int i = lane; i < n_dims; i += 32) {
+  for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) {
     scalar_t h = static_cast<scalar_t>(0.0);
-    if (blockIdx.x == 0 && warp == 0) {
+    if (blockIdx.x == 0 && lane == 0) {
       if (initial_state != NULL) {
         h = initial_state[i];
       }
     } else {
-      h = h_storage[i + (warp - 1) * n_dims + blockIdx.x * 33 * n_dims];
+      h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims];
     }
 
     for (int t = warp_start; t < warp_stop; t++) {
-      h = decays[i + t * n_dims] * h + impulses[i + t * n_dims];
-      out[i + t * n_dims] = h;
+      h = decays[i * n_steps + t] * h + impulses[i * n_steps + t];
+      out[i * n_steps + t] = h;
     }
   }
 }
@@ -233,13 +233,10 @@ template <typename scalar_t>
 void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
                                const scalar_t *initial_state, scalar_t *out,
                                int n_dims, int n_steps) {
-  // TODO: query
-  int n_SMs = 15;
-  int n_blocks_per_sm = 2;
-
   // we want at least 32 elements per block, but no reason to run
   // with more than the maximum number of concurrent blocks
-  int n_blocks = min(CEIL_DIV(n_steps, 32), n_SMs * n_blocks_per_sm);
+  // NOTE: 128 is decided empirically.
+  int n_blocks = min(CEIL_DIV(n_steps, 32), 128);
 
   // TODO: make user pass in working memory? This allows integration
   // with CNMeM (used by Theano)
@@ -273,8 +270,8 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
   TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
               "Weights must have the same scalar type as input");
 
-  auto input_contiguous = input.transpose(0, 1).contiguous();
-  auto weights_contiguous = weights.transpose(0, 1).contiguous();
+  auto input_contiguous = input.contiguous();
+  auto weights_contiguous = weights.contiguous();
   auto output = at::empty_like(input_contiguous);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -285,10 +282,10 @@ at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
         weights_contiguous.const_data_ptr<scalar_t>(),
         input_contiguous.const_data_ptr<scalar_t>(),
         initials.const_data_ptr<scalar_t>(),
-        output.mutable_data_ptr<scalar_t>(), input_contiguous.size(1),
-        input_contiguous.size(0));
+        output.mutable_data_ptr<scalar_t>(), input_contiguous.size(0),
+        input_contiguous.size(1));
   });
-  return output.transpose(0, 1).contiguous();
+  return output.contiguous();
 }
 
 TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
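Aside on the grid-sizing hunk: the hardcoded n_SMs = 15 and n_blocks_per_sm = 2 gave way to an empirically chosen cap of 128 blocks. For reference, here is a hypothetical sketch of what the removed "TODO: query" comment pointed at, deriving the cap from the device at runtime instead; max_concurrent_blocks is an illustrative helper, not part of the commit, and the 2 blocks per SM is just the old hardcoded assumption carried over.

#include <cuda_runtime.h>

// Hypothetical helper: query the SM count instead of hardcoding it.
int max_concurrent_blocks(int device) {
  int n_sms = 0;
  cudaDeviceGetAttribute(&n_sms, cudaDevAttrMultiProcessorCount, device);
  int n_blocks_per_sm = 2; // assumption carried over from the removed constant
  return n_sms * n_blocks_per_sm;
}

An occupancy-derived bound (e.g. via cudaOccupancyMaxActiveBlocksPerMultiprocessor) would be tighter still, but the fixed 128 keeps the launch path free of per-call device queries.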
