1010
1111import numpy as np
1212import scipy .special
13- import scipy .stats
1413
1514from pytensor .configdefaults import config
1615from pytensor .gradient import grad_not_implemented , grad_undefined
@@ -261,12 +260,8 @@ def c_code(self, node, name, inp, out, sub):
261260class Owens_t (BinaryScalarOp ):
262261 nfunc_spec = ("scipy.special.owens_t" , 2 , 1 )
263262
264- @staticmethod
265- def st_impl (h , a ):
266- return scipy .special .owens_t (h , a )
267-
268263 def impl (self , h , a ):
269- return Owens_t . st_impl (h , a )
264+ return scipy . special . owens_t (h , a )
270265
271266 def grad (self , inputs , grads ):
272267 (h , a ) = inputs
@@ -290,12 +285,8 @@ def c_code(self, *args, **kwargs):
290285class Gamma (UnaryScalarOp ):
291286 nfunc_spec = ("scipy.special.gamma" , 1 , 1 )
292287
293- @staticmethod
294- def st_impl (x ):
295- return scipy .special .gamma (x )
296-
297288 def impl (self , x ):
298- return Gamma . st_impl (x )
289+ return scipy . special . gamma (x )
299290
300291 def L_op (self , inputs , outputs , gout ):
301292 (x ,) = inputs
@@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp):
329320
330321 nfunc_spec = ("scipy.special.gammaln" , 1 , 1 )
331322
332- @staticmethod
333- def st_impl (x ):
334- return scipy .special .gammaln (x )
335-
336323 def impl (self , x ):
337- return GammaLn . st_impl (x )
324+ return scipy . special . gammaln (x )
338325
339326 def L_op (self , inputs , outputs , grads ):
340327 (x ,) = inputs
@@ -373,12 +360,8 @@ class Psi(UnaryScalarOp):
373360
374361 nfunc_spec = ("scipy.special.psi" , 1 , 1 )
375362
376- @staticmethod
377- def st_impl (x ):
378- return scipy .special .psi (x )
379-
380363 def impl (self , x ):
381- return Psi . st_impl (x )
364+ return scipy . special . psi (x )
382365
383366 def L_op (self , inputs , outputs , grads ):
384367 (x ,) = inputs
@@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp):
464447
465448 """
466449
467- @staticmethod
468- def st_impl (x ):
469- return scipy .special .polygamma (1 , x )
470-
471450 def impl (self , x ):
472- return TriGamma . st_impl ( x )
451+ return scipy . special . polygamma ( 1 , x )
473452
474453 def L_op (self , inputs , outputs , outputs_gradients ):
475454 (x ,) = inputs
@@ -567,12 +546,8 @@ def output_types_preference(n_type, x_type):
567546 # Scipy doesn't support it
568547 return upgrade_to_float_no_complex (x_type )
569548
570- @staticmethod
571- def st_impl (n , x ):
572- return scipy .special .polygamma (n , x )
573-
574549 def impl (self , n , x ):
575- return PolyGamma . st_impl (n , x )
550+ return scipy . special . polygamma (n , x )
576551
577552 def L_op (self , inputs , outputs , output_gradients ):
578553 (n , x ) = inputs
@@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp):
598573
599574 nfunc_spec = ("scipy.special.gammainc" , 2 , 1 )
600575
601- @staticmethod
602- def st_impl (k , x ):
603- return scipy .special .gammainc (k , x )
604-
605576 def impl (self , k , x ):
606- return GammaInc . st_impl (k , x )
577+ return scipy . special . gammainc (k , x )
607578
608579 def grad (self , inputs , grads ):
609580 (k , x ) = inputs
@@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp):
649620
650621 nfunc_spec = ("scipy.special.gammaincc" , 2 , 1 )
651622
652- @staticmethod
653- def st_impl (k , x ):
654- return scipy .special .gammaincc (k , x )
655-
656623 def impl (self , k , x ):
657- return GammaIncC . st_impl (k , x )
624+ return scipy . special . gammaincc (k , x )
658625
659626 def grad (self , inputs , grads ):
660627 (k , x ) = inputs
@@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp):
700667
701668 nfunc_spec = ("scipy.special.gammaincinv" , 2 , 1 )
702669
703- @staticmethod
704- def st_impl (k , x ):
705- return scipy .special .gammaincinv (k , x )
706-
707670 def impl (self , k , x ):
708- return GammaIncInv . st_impl (k , x )
671+ return scipy . special . gammaincinv (k , x )
709672
710673 def grad (self , inputs , grads ):
711674 (k , x ) = inputs
@@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp):
729692
730693 nfunc_spec = ("scipy.special.gammainccinv" , 2 , 1 )
731694
732- @staticmethod
733- def st_impl (k , x ):
734- return scipy .special .gammainccinv (k , x )
735-
736695 def impl (self , k , x ):
737- return GammaIncCInv . st_impl (k , x )
696+ return scipy . special . gammainccinv (k , x )
738697
739698 def grad (self , inputs , grads ):
740699 (k , x ) = inputs
@@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp):
968927
969928 # Note there is no basic SciPy version so no nfunc_spec.
970929
971- @staticmethod
972- def st_impl (k , x ):
973- return scipy .special .gammaincc (k , x ) * scipy .special .gamma (k )
974-
975930 def impl (self , k , x ):
976- return GammaU . st_impl (k , x )
931+ return scipy . special . gammaincc (k , x ) * scipy . special . gamma ( k )
977932
978933 def c_support_code (self , ** kwargs ):
979934 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp):
1004959
1005960 # Note there is no basic SciPy version so no nfunc_spec.
1006961
1007- @staticmethod
1008- def st_impl (k , x ):
1009- return scipy .special .gammainc (k , x ) * scipy .special .gamma (k )
1010-
1011962 def impl (self , k , x ):
1012- return GammaL . st_impl (k , x )
963+ return scipy . special . gammainc (k , x ) * scipy . special . gamma ( k )
1013964
1014965 def c_support_code (self , ** kwargs ):
1015966 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp):
1040991
1041992 nfunc_spec = ("scipy.special.jv" , 2 , 1 )
1042993
1043- @staticmethod
1044- def st_impl (v , x ):
1045- return scipy .special .jv (v , x )
1046-
1047994 def impl (self , v , x ):
1048- return self . st_impl (v , x )
995+ return scipy . special . jv (v , x )
1049996
1050997 def grad (self , inputs , grads ):
1051998 v , x = inputs
@@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp):
10691016
10701017 nfunc_spec = ("scipy.special.j1" , 1 , 1 )
10711018
1072- @staticmethod
1073- def st_impl (x ):
1074- return scipy .special .j1 (x )
1075-
10761019 def impl (self , x ):
1077- return self . st_impl (x )
1020+ return scipy . special . j1 (x )
10781021
10791022 def grad (self , inputs , grads ):
10801023 (x ,) = inputs
@@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp):
11001043
11011044 nfunc_spec = ("scipy.special.j0" , 1 , 1 )
11021045
1103- @staticmethod
1104- def st_impl (x ):
1105- return scipy .special .j0 (x )
1106-
11071046 def impl (self , x ):
1108- return self . st_impl (x )
1047+ return scipy . special . j0 (x )
11091048
11101049 def grad (self , inp , grads ):
11111050 (x ,) = inp
@@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp):
11311070
11321071 nfunc_spec = ("scipy.special.iv" , 2 , 1 )
11331072
1134- @staticmethod
1135- def st_impl (v , x ):
1136- return scipy .special .iv (v , x )
1137-
11381073 def impl (self , v , x ):
1139- return self . st_impl (v , x )
1074+ return scipy . special . iv (v , x )
11401075
11411076 def grad (self , inputs , grads ):
11421077 v , x = inputs
@@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp):
11601095
11611096 nfunc_spec = ("scipy.special.i1" , 1 , 1 )
11621097
1163- @staticmethod
1164- def st_impl (x ):
1165- return scipy .special .i1 (x )
1166-
11671098 def impl (self , x ):
1168- return self . st_impl (x )
1099+ return scipy . special . i1 (x )
11691100
11701101 def grad (self , inputs , grads ):
11711102 (x ,) = inputs
@@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp):
11861117
11871118 nfunc_spec = ("scipy.special.i0" , 1 , 1 )
11881119
1189- @staticmethod
1190- def st_impl (x ):
1191- return scipy .special .i0 (x )
1192-
11931120 def impl (self , x ):
1194- return self . st_impl (x )
1121+ return scipy . special . i0 (x )
11951122
11961123 def grad (self , inp , grads ):
11971124 (x ,) = inp
@@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp):
12121139
12131140 nfunc_spec = ("scipy.special.ive" , 2 , 1 )
12141141
1215- @staticmethod
1216- def st_impl (v , x ):
1217- return scipy .special .ive (v , x )
1218-
12191142 def impl (self , v , x ):
1220- return self . st_impl (v , x )
1143+ return scipy . special . ive (v , x )
12211144
12221145 def grad (self , inputs , grads ):
12231146 v , x = inputs
@@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp):
12411164
12421165 nfunc_spec = ("scipy.special.kve" , 2 , 1 )
12431166
1244- @staticmethod
1245- def st_impl (v , x ):
1246- return scipy .special .kve (v , x )
1247-
12481167 def impl (self , v , x ):
1249- return self . st_impl (v , x )
1168+ return scipy . special . kve (v , x )
12501169
12511170 def L_op (self , inputs , outputs , output_grads ):
12521171 v , x = inputs
@@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp):
13271246 "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
13281247 """
13291248
1330- @staticmethod
1331- def static_impl (x ):
1249+ def impl (self , x ):
13321250 # If x is an int8 or uint8, numpy.exp will compute the result in
13331251 # half-precision (float16), where we want float32.
13341252 not_int8 = str (getattr (x , "dtype" , "" )) not in ("int8" , "uint8" )
@@ -1343,9 +1261,6 @@ def static_impl(x):
13431261 else :
13441262 return x
13451263
1346- def impl (self , x ):
1347- return Softplus .static_impl (x )
1348-
13491264 def grad (self , inp , grads ):
13501265 (x ,) = inp
13511266 (gz ,) = grads
@@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp):
14081323 "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
14091324 """
14101325
1411- @staticmethod
1412- def static_impl (x ):
1326+ def impl (self , x ):
14131327 if x < np .log (0.5 ):
14141328 return np .log1p (- np .exp (x ))
14151329 else :
14161330 return np .log (- np .expm1 (x ))
14171331
1418- def impl (self , x ):
1419- return Log1mexp .static_impl (x )
1420-
14211332 def grad (self , inp , grads ):
14221333 (x ,) = inp
14231334 (gz ,) = grads
@@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp):
17491660 nin = 4
17501661 nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
17511662
1752- @staticmethod
1753- def st_impl (a , b , c , z ):
1754- return scipy .special .hyp2f1 (a , b , c , z )
1755-
17561663 def impl (self , a , b , c , z ):
1757- return Hyp2F1 . st_impl (a , b , c , z )
1664+ return scipy . special . hyp2f1 (a , b , c , z )
17581665
17591666 def grad (self , inputs , grads ):
17601667 a , b , c , z = inputs
0 commit comments