Skip to content

Commit 6d4a2de

Browse files
committed
refactor: implement separate recurrence functions for improved clarity and maintainability
1 parent 206497e commit 6d4a2de

File tree

1 file changed

+49
-28
lines changed

1 file changed

+49
-28
lines changed

torchlpc/recurrence.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,64 @@
99
from . import EXTENSION_LOADED
1010

1111

12+
def _cuda_recurrence(
    impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
    """Run the first-order linear recurrence on CUDA tensors.

    Chooses between three backends based on problem shape and build config:
    a compiled parallel-scan op, a numba kernel, or the sequential
    ``lpc_cuda`` fallback.

    Args:
        impulse: Input excitation, same shape as the output.
        decay: Per-step decay coefficients of shape ``(n_dims, n_steps)``.
        initial_state: Recurrence state preceding the first step.

    Returns:
        The recurrence output as a tensor shaped like ``impulse``.
    """
    batch, steps = decay.shape

    # Sequences that are short relative to the batch size are handled by the
    # sequential CUDA kernel; otherwise a parallel scan pays off.
    if not (batch * WARPSIZE < steps):
        return lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))

    if EXTENSION_LOADED:
        # Compiled extension available: use its fused scan op directly.
        return torch.ops.torchlpc.scan(impulse, decay, initial_state)

    # Fallback: numba kernel writing into a preallocated output buffer.
    out = torch.empty_like(impulse)
    compute_linear_recurrence(
        cuda.as_cuda_array(decay.detach()),
        cuda.as_cuda_array(impulse.detach()),
        cuda.as_cuda_array(initial_state.detach()),
        cuda.as_cuda_array(out),
        batch,
        steps,
    )
    return out
38+
39+
40+
def _cpu_recurrence(
    impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
    """Run the first-order linear recurrence on CPU tensors.

    Uses the compiled scan op when it is available and the batch is small
    relative to the thread count; otherwise falls back to the numpy
    implementation ``lpc_np``.

    Args:
        impulse: Input excitation, same shape as the output.
        decay: Per-step decay coefficients of shape ``(n_dims, n_steps)``.
        initial_state: Recurrence state preceding the first step.

    Returns:
        The recurrence output as a tensor shaped like ``impulse``.
    """
    threads = torch.get_num_threads()
    batch = decay.shape[0]

    # This is just a rough estimation of the computational cost
    if EXTENSION_LOADED and min(batch, threads) < threads / 3:
        return torch.ops.torchlpc.scan(impulse, decay, initial_state)

    result = lpc_np(
        impulse.detach().numpy(),
        -decay.unsqueeze(2).detach().numpy(),
        initial_state.unsqueeze(1).detach().numpy(),
    )
    return torch.from_numpy(result)
57+
58+
1259
class Recurrence(Function):
1360
@staticmethod
1461
def forward(
1562
decay: torch.Tensor,
1663
impulse: torch.Tensor,
1764
initial_state: torch.Tensor,
1865
) -> torch.Tensor:
19-
n_dims, n_steps = decay.shape
2066
if decay.is_cuda:
21-
if n_dims * WARPSIZE < n_steps:
22-
if EXTENSION_LOADED:
23-
out = torch.ops.torchlpc.scan(impulse, decay, initial_state)
24-
else:
25-
out = torch.empty_like(impulse)
26-
compute_linear_recurrence(
27-
cuda.as_cuda_array(decay.detach()),
28-
cuda.as_cuda_array(impulse.detach()),
29-
cuda.as_cuda_array(initial_state.detach()),
30-
cuda.as_cuda_array(out),
31-
n_dims,
32-
n_steps,
33-
)
34-
else:
35-
out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))
67+
out = _cuda_recurrence(impulse, decay, initial_state)
3668
else:
37-
num_threads = torch.get_num_threads()
38-
# This is just a rough estimation of the computational cost
39-
if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
40-
out = torch.ops.torchlpc.scan(impulse, decay, initial_state)
41-
else:
42-
out = torch.from_numpy(
43-
lpc_np(
44-
impulse.detach().numpy(),
45-
-decay.unsqueeze(2).detach().numpy(),
46-
initial_state.unsqueeze(1).detach().numpy(),
47-
)
48-
)
69+
out = _cpu_recurrence(impulse, decay, initial_state)
4970
return out
5071

5172
@staticmethod

0 commit comments

Comments
 (0)