
Commit 9164dcb

refactor linear recurrence code
1 parent ad77bbb commit 9164dcb

File tree

3 files changed: +57 -96 lines changed

File renamed without changes.

torchlpc/csrc/linear_recurrent_net/linear_recurrence.cu renamed to torchlpc/csrc/cuda/linear_recurrence.cu

Lines changed: 57 additions & 86 deletions
@@ -1,5 +1,9 @@
 #include <assert.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <stdio.h>
+#include <torch/script.h>
+#include <torch/torch.h>
 
 #define CEIL_DIV(x, y) ((x + y - 1) / y)
 
@@ -57,14 +61,17 @@ __device__ int2 compute_warp_start_stop(int block_idx, int warp_idx,
 // decay storage, h_storage:
 //   each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block
 //   reduction
-__global__ void reduction_kernel(float *decays, float *impulses,
-                                 float *initial_state, float *_decay_storage,
-                                 float *_h_storage, int n_dims, int n_steps) {
+template <typename scalar_t>
+__global__ void reduction_kernel(const scalar_t *decays,
+                                 const scalar_t *impulses,
+                                 const scalar_t *initial_state,
+                                 scalar_t *_decay_storage, scalar_t *_h_storage,
+                                 int n_dims, int n_steps) {
   int warp = threadIdx.x / 32;
   int lane = threadIdx.x % 32;
 
-  float *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims];
-  float *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
+  scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims];
+  scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims];
 
   int2 start_stop =
       compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps);
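
The kernels are now templated on scalar_t rather than hard-coded to float, so one kernel body can be compiled for float, double, and complex element types. For readers less familiar with templated CUDA kernels, here is a minimal standalone sketch of the same pattern (illustrative only, not part of this commit; fill_kernel and its arguments are made up):

#include <cuda_runtime.h>

// One templated kernel body serves every element type it is instantiated for.
template <typename scalar_t>
__global__ void fill_kernel(scalar_t *out, scalar_t value, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = value;
}

int main() {
  double *d_out;
  cudaMalloc(&d_out, 128 * sizeof(double));
  // The launch below instantiates fill_kernel<double> at compile time,
  // just as the host code below instantiates reduction_kernel<scalar_t>.
  fill_kernel<double><<<1, 128>>>(d_out, 0.5, 128);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}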
@@ -78,8 +85,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
    * (feature_idx, warp, block).
    */
   for (int i = lane; i < n_dims; i += 32) {
-    float cum_decay = 1.0;
-    float h = 0.0;
+    scalar_t cum_decay = static_cast<scalar_t>(1.0);
+    scalar_t h = static_cast<scalar_t>(0.0);
     if (blockIdx.x == 0 && warp == 0 && initial_state != NULL) {
       h = initial_state[i];
     }
@@ -106,8 +113,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
   // TODO: parallel reduction (or scan). Need to worry about changing the warp
   // reduction values (as I use them again later)
   for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) {
-    float cum_decay = 1.0;
-    float h = 0.0;
+    scalar_t cum_decay = static_cast<scalar_t>(1.0);
+    scalar_t h = static_cast<scalar_t>(0.0);
     for (int t = 0; t < 32; t++) {
       cum_decay *= decay_storage[i + t * n_dims];
       h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims];
@@ -117,7 +124,8 @@ __global__ void reduction_kernel(float *decays, float *impulses,
   }
 }
 
-__global__ void block_scan_kernel(float *decay_storage, float *h_storage,
+template <typename scalar_t>
+__global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage,
                                   int n_dims, int n_blocks) {
   /*
    * Scan over blocks.
@@ -142,9 +150,11 @@ __global__ void block_scan_kernel(float *decay_storage, float *h_storage,
   }
 }
 
-__global__ void warp_scan_kernel(float *decays, float *impulses,
-                                 float *initial_state, float *out,
-                                 float *decay_storage, float *h_storage,
+template <typename scalar_t>
+__global__ void warp_scan_kernel(const scalar_t *decays,
+                                 const scalar_t *impulses,
+                                 const scalar_t *initial_state, scalar_t *out,
+                                 scalar_t *decay_storage, scalar_t *h_storage,
                                  int n_dims, int n_steps) {
   int warp = threadIdx.x / 32;
   int lane = threadIdx.x % 32;
@@ -196,7 +206,7 @@ __global__ void warp_scan_kernel(float *decays, float *impulses,
    * writes to output for indices warp_start up to warp_stop.
    */
   for (int i = lane; i < n_dims; i += 32) {
-    float h = 0.0;
+    scalar_t h = static_cast<scalar_t>(0.0);
     if (blockIdx.x == 0 && warp == 0) {
       if (initial_state != NULL) {
         h = initial_state[i];
@@ -212,34 +222,17 @@ __global__ void warp_scan_kernel(float *decays, float *impulses,
   }
 }
 
-__global__ void serial_linear_recurrence(float *decays, float *impulses,
-                                         float *initial_state, float *out,
-                                         int n_dims, int n_steps) {
-  // computes h_t = lambda_t h{t-1} + x_t
-
-  for (int dim_idx = threadIdx.x + blockIdx.x * blockDim.x; dim_idx < n_dims;
-       dim_idx += blockDim.x * gridDim.x) {
-    float val = initial_state[dim_idx];
-
-    for (int step = 0; step < n_steps; step++) {
-      int idx = dim_idx + step * n_dims;
-      val = decays[idx] * val + impulses[idx];
-      out[idx] = val;
-    }
-  }
-}
-
-extern "C" {
 /*
  * This is the main method for the prefix sum kernels.
  * decays, impulses, out:
  *   each a n_dims x n_steps column major matrix located on GPU
  * initial_state:
  *   array of size n_dims located on GPU
  */
-void compute_linear_recurrence(float *decays, float *impulses,
-                               float *initial_state, float *out, int n_dims,
-                               int n_steps) {
+template <typename scalar_t>
+void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
+                               const scalar_t *initial_state, scalar_t *out,
+                               int n_dims, int n_steps) {
   // TODO: query
   int n_SMs = 15;
   int n_blocks_per_sm = 2;
@@ -251,10 +244,10 @@ void compute_linear_recurrence(float *decays, float *impulses,
   // TODO: make user pass in working memory? This allows integration
   // with CNMeM (used by Theano)
   int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(float);
-  float *d_reduction_mem;
+  scalar_t *d_reduction_mem;
   gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz));
-  float *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims];
-  float *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims];
+  scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims];
+  scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims];
 
   // TODO: run kernels on non-default stream?
   reduction_kernel<<<n_blocks, 1024>>>(decays, impulses, initial_state,
@@ -271,53 +264,31 @@ void compute_linear_recurrence(float *decays, float *impulses,
   gpuErrChk(cudaFree(d_reduction_mem));
 }
 
-void compute_serial_linear_recurrence(float *decays, float *impulses,
-                                      float *initial_state, float *out,
-                                      int n_dims, int n_steps) {
-  // TODO: query
-  int n_SMs = 15;
-  int n_blocks_per_sm = 2;
-
-  int n_blocks = n_SMs * n_blocks_per_sm;
-  serial_linear_recurrence<<<n_blocks, 1024>>>(
-      decays, impulses, initial_state, out, n_dims, n_steps);
-}
+at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,
+                             const at::Tensor &initials) {
+  TORCH_CHECK(input.is_floating_point() || input.is_complex(),
+              "Input must be floating point or complex");
+  TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
+              "Initials must have the same scalar type as input");
+  TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
+              "Weights must have the same scalar type as input");
+
+  auto input_contiguous = input.transpose(0, 1).contiguous();
+  auto weights_contiguous = weights.transpose(0, 1).contiguous();
+  auto output = at::empty_like(input_contiguous);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      input.scalar_type(), "compute_linear_recurrence", [&] {
+        compute_linear_recurrence<scalar_t>(
+            weights_contiguous.const_data_ptr<scalar_t>(),
+            input_contiguous.const_data_ptr<scalar_t>(),
+            initials.const_data_ptr<scalar_t>(),
+            output.mutable_data_ptr<scalar_t>(), input.size(1),
+            input.size(0));
+      });
+  return output.transpose(0, 1).contiguous();
 }
 
-void test() {
-  int n_dims = 100;
-  int n_steps = 1000000;
-  int n_elements = n_dims * n_steps;
-
-  float *decays = (float *)calloc(n_elements, sizeof(float));
-  for (int i = 0; i < n_elements; i++) {
-    decays[i] = .999;
-  }
-  float *d_decays;
-  gpuErrChk(cudaMalloc(&d_decays, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemcpy(d_decays, decays, n_elements * sizeof(float),
-                       cudaMemcpyHostToDevice));
-
-  float *impulses = (float *)calloc(n_elements, sizeof(float));
-  for (int i = 0; i < n_dims; i++) {
-    impulses[i + 0 * n_dims] = 2.0;
-  }
-  float *d_impulses;
-  gpuErrChk(cudaMalloc(&d_impulses, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemcpy(d_impulses, impulses, n_elements * sizeof(float),
-                       cudaMemcpyHostToDevice));
-
-  float *out = (float *)calloc(n_elements, sizeof(float));
-  float *d_out;
-  gpuErrChk(cudaMalloc(&d_out, n_elements * sizeof(float)));
-  gpuErrChk(cudaMemset(d_out, 0, n_elements * sizeof(float)));
-
-  compute_linear_recurrence(d_decays, d_impulses, NULL, d_out, n_dims,
-                            n_steps);
-  gpuErrChk(cudaMemcpy(out, d_out, n_elements * sizeof(float),
-                       cudaMemcpyDeviceToHost));
-
-  gpuErrChk(cudaFree(d_decays));
-  gpuErrChk(cudaFree(d_impulses));
-  gpuErrChk(cudaFree(d_out));
-}
+TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); }
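
The m.impl call above registers scan_cuda_wrapper as the CUDA backend of a torchlpc::scan operator. The operator schema itself is not part of this diff and is presumably declared elsewhere in the extension; a hedged sketch of what such a declaration could look like (the schema string and its location are assumptions, only the namespace, operator name, and argument order come from the wrapper above):

#include <torch/library.h>

// Hypothetical schema declaration; the real one lives in another source file.
TORCH_LIBRARY(torchlpc, m) {
  // Mirrors scan_cuda_wrapper(input, weights, initials) -> Tensor.
  m.def("scan(Tensor input, Tensor weights, Tensor initials) -> Tensor");
}

With a schema along these lines in place, the CUDA implementation registered here would be callable from Python as torch.ops.torchlpc.scan(input, weights, initials).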

torchlpc/csrc/linear_recurrent_net/linear_recurrence.h

Lines changed: 0 additions & 10 deletions
This file was deleted.
