Skip to content

Commit a4fd535

Browse files
committed
refactor: update function calls to use unified 'scan' operation
1 parent 6c8570c commit a4fd535

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

torchlpc/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
162162
if x.is_cuda:
163163
y = lpc_cuda(x.detach(), A.detach(), zi.detach())
164164
elif EXTENSION_LOADED:
165-
y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
165+
y = torch.ops.torchlpc.lpc(x, A, zi)
166166
else:
167167
warnings.warn(
168168
"Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."

torchlpc/recurrence.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,25 @@ def forward(
1919
n_dims, n_steps = decay.shape
2020
if decay.is_cuda:
2121
if n_dims * WARPSIZE < n_steps:
22-
out = torch.empty_like(impulse)
23-
compute_linear_recurrence(
24-
cuda.as_cuda_array(decay.detach()),
25-
cuda.as_cuda_array(impulse.detach()),
26-
cuda.as_cuda_array(initial_state.detach()),
27-
cuda.as_cuda_array(out),
28-
n_dims,
29-
n_steps,
30-
)
22+
if EXTENSION_LOADED:
23+
out = torch.ops.torchlpc.scan(impulse, decay, initial_state)
24+
else:
25+
out = torch.empty_like(impulse)
26+
compute_linear_recurrence(
27+
cuda.as_cuda_array(decay.detach()),
28+
cuda.as_cuda_array(impulse.detach()),
29+
cuda.as_cuda_array(initial_state.detach()),
30+
cuda.as_cuda_array(out),
31+
n_dims,
32+
n_steps,
33+
)
3134
else:
3235
out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))
3336
else:
3437
num_threads = torch.get_num_threads()
3538
# This is just a rough estimation of the computational cost
3639
if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
37-
out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state)
40+
out = torch.ops.torchlpc.scan(impulse, decay, initial_state)
3841
else:
3942
out = torch.from_numpy(
4043
lpc_np(

0 commit comments

Comments
 (0)