|
9 | 9 | from . import EXTENSION_LOADED |
10 | 10 |
|
11 | 11 |
|
def _cuda_recurrence(
    impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
    """Evaluate the recurrence on CUDA tensors, selecting a backend.

    When the sequence is long relative to the batch (``n_dims * WARPSIZE <
    n_steps``) the scan-based kernels are used: the compiled extension if it
    loaded, otherwise the numba ``compute_linear_recurrence`` kernel. Shorter
    sequences go through ``lpc_cuda`` instead.
    """
    batch, length = decay.shape

    # Short-sequence path: lpc_cuda takes negated coefficients with an extra
    # trailing axis and the initial state with an inserted axis.
    if batch * WARPSIZE >= length:
        return lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))

    # Long-sequence path, preferred backend: the compiled extension's scan.
    if EXTENSION_LOADED:
        return torch.ops.torchlpc.scan(impulse, decay, initial_state)

    # Fallback: numba kernel writing into a pre-allocated output buffer.
    # Inputs are detached and exposed to numba via the CUDA array interface.
    out = torch.empty_like(impulse)
    compute_linear_recurrence(
        cuda.as_cuda_array(decay.detach()),
        cuda.as_cuda_array(impulse.detach()),
        cuda.as_cuda_array(initial_state.detach()),
        cuda.as_cuda_array(out),
        batch,
        length,
    )
    return out
| 39 | + |
def _cpu_recurrence(
    impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
    """Evaluate the recurrence on CPU tensors, selecting a backend.

    Uses the compiled extension's scan when it is loaded and the batch is
    small compared to the available threads; otherwise falls back to the
    numpy-based ``lpc_np`` implementation.
    """
    threads = torch.get_num_threads()
    batch = decay.size(0)

    # This is just a rough estimation of the computational cost.
    prefer_extension = EXTENSION_LOADED and min(batch, threads) < threads / 3
    if prefer_extension:
        return torch.ops.torchlpc.scan(impulse, decay, initial_state)

    # lpc_np works on detached numpy views; it takes negated coefficients
    # with an extra trailing axis and the initial state with an inserted axis.
    result = lpc_np(
        impulse.detach().numpy(),
        -decay.unsqueeze(2).detach().numpy(),
        initial_state.unsqueeze(1).detach().numpy(),
    )
    return torch.from_numpy(result)
| 58 | + |
12 | 59 | class Recurrence(Function): |
13 | 60 | @staticmethod |
14 | 61 | def forward( |
15 | 62 | decay: torch.Tensor, |
16 | 63 | impulse: torch.Tensor, |
17 | 64 | initial_state: torch.Tensor, |
18 | 65 | ) -> torch.Tensor: |
19 | | - n_dims, n_steps = decay.shape |
20 | 66 | if decay.is_cuda: |
21 | | - if n_dims * WARPSIZE < n_steps: |
22 | | - if EXTENSION_LOADED: |
23 | | - out = torch.ops.torchlpc.scan(impulse, decay, initial_state) |
24 | | - else: |
25 | | - out = torch.empty_like(impulse) |
26 | | - compute_linear_recurrence( |
27 | | - cuda.as_cuda_array(decay.detach()), |
28 | | - cuda.as_cuda_array(impulse.detach()), |
29 | | - cuda.as_cuda_array(initial_state.detach()), |
30 | | - cuda.as_cuda_array(out), |
31 | | - n_dims, |
32 | | - n_steps, |
33 | | - ) |
34 | | - else: |
35 | | - out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)) |
| 67 | + out = _cuda_recurrence(impulse, decay, initial_state) |
36 | 68 | else: |
37 | | - num_threads = torch.get_num_threads() |
38 | | - # This is just a rough estimation of the computational cost |
39 | | - if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3: |
40 | | - out = torch.ops.torchlpc.scan(impulse, decay, initial_state) |
41 | | - else: |
42 | | - out = torch.from_numpy( |
43 | | - lpc_np( |
44 | | - impulse.detach().numpy(), |
45 | | - -decay.unsqueeze(2).detach().numpy(), |
46 | | - initial_state.unsqueeze(1).detach().numpy(), |
47 | | - ) |
48 | | - ) |
| 69 | + out = _cpu_recurrence(impulse, decay, initial_state) |
49 | 70 | return out |
50 | 71 |
|
51 | 72 | @staticmethod |
|
0 commit comments