@@ -19,22 +19,25 @@ def forward(
1919 n_dims , n_steps = decay .shape
2020 if decay .is_cuda :
2121 if n_dims * WARPSIZE < n_steps :
22- out = torch .empty_like (impulse )
23- compute_linear_recurrence (
24- cuda .as_cuda_array (decay .detach ()),
25- cuda .as_cuda_array (impulse .detach ()),
26- cuda .as_cuda_array (initial_state .detach ()),
27- cuda .as_cuda_array (out ),
28- n_dims ,
29- n_steps ,
30- )
22+ if EXTENSION_LOADED :
23+ out = torch .ops .torchlpc .scan (impulse , decay , initial_state )
24+ else :
25+ out = torch .empty_like (impulse )
26+ compute_linear_recurrence (
27+ cuda .as_cuda_array (decay .detach ()),
28+ cuda .as_cuda_array (impulse .detach ()),
29+ cuda .as_cuda_array (initial_state .detach ()),
30+ cuda .as_cuda_array (out ),
31+ n_dims ,
32+ n_steps ,
33+ )
3134 else :
3235 out = lpc_cuda (impulse , - decay .unsqueeze (2 ), initial_state .unsqueeze (1 ))
3336 else :
3437 num_threads = torch .get_num_threads ()
3538 # This is just a rough estimation of the computational cost
3639 if EXTENSION_LOADED and min (n_dims , num_threads ) < num_threads / 3 :
37- out = torch .ops .torchlpc .scan_cpu (impulse , decay , initial_state )
40+ out = torch .ops .torchlpc .scan (impulse , decay , initial_state )
3841 else :
3942 out = torch .from_numpy (
4043 lpc_np (
0 commit comments