1010
1111import numpy as np
1212import scipy .special
13- import scipy .stats
1413
1514from pytensor .configdefaults import config
1615from pytensor .gradient import grad_not_implemented , grad_undefined
@@ -262,12 +261,8 @@ def c_code(self, node, name, inp, out, sub):
262261class Owens_t (BinaryScalarOp ):
263262 nfunc_spec = ("scipy.special.owens_t" , 2 , 1 )
264263
265- @staticmethod
266- def st_impl (h , a ):
267- return scipy .special .owens_t (h , a )
268-
269264 def impl (self , h , a ):
270- return Owens_t . st_impl (h , a )
265+ return scipy . special . owens_t (h , a )
271266
272267 def grad (self , inputs , grads ):
273268 (h , a ) = inputs
@@ -291,12 +286,8 @@ def c_code(self, *args, **kwargs):
291286class Gamma (UnaryScalarOp ):
292287 nfunc_spec = ("scipy.special.gamma" , 1 , 1 )
293288
294- @staticmethod
295- def st_impl (x ):
296- return scipy .special .gamma (x )
297-
298289 def impl (self , x ):
299- return Gamma . st_impl (x )
290+ return scipy . special . gamma (x )
300291
301292 def L_op (self , inputs , outputs , gout ):
302293 (x ,) = inputs
@@ -330,12 +321,8 @@ class GammaLn(UnaryScalarOp):
330321
331322 nfunc_spec = ("scipy.special.gammaln" , 1 , 1 )
332323
333- @staticmethod
334- def st_impl (x ):
335- return scipy .special .gammaln (x )
336-
337324 def impl (self , x ):
338- return GammaLn . st_impl (x )
325+ return scipy . special . gammaln (x )
339326
340327 def L_op (self , inputs , outputs , grads ):
341328 (x ,) = inputs
@@ -374,12 +361,8 @@ class Psi(UnaryScalarOp):
374361
375362 nfunc_spec = ("scipy.special.psi" , 1 , 1 )
376363
377- @staticmethod
378- def st_impl (x ):
379- return scipy .special .psi (x )
380-
381364 def impl (self , x ):
382- return Psi . st_impl (x )
365+ return scipy . special . psi (x )
383366
384367 def L_op (self , inputs , outputs , grads ):
385368 (x ,) = inputs
@@ -465,12 +448,8 @@ class TriGamma(UnaryScalarOp):
465448
466449 """
467450
468- @staticmethod
469- def st_impl (x ):
470- return scipy .special .polygamma (1 , x )
471-
472451 def impl (self , x ):
473- return TriGamma . st_impl ( x )
452+ return scipy . special . polygamma ( 1 , x )
474453
475454 def L_op (self , inputs , outputs , outputs_gradients ):
476455 (x ,) = inputs
@@ -568,12 +547,8 @@ def output_types_preference(n_type, x_type):
568547 # Scipy doesn't support it
569548 return upgrade_to_float_no_complex (x_type )
570549
571- @staticmethod
572- def st_impl (n , x ):
573- return scipy .special .polygamma (n , x )
574-
575550 def impl (self , n , x ):
576- return PolyGamma . st_impl (n , x )
551+ return scipy . special . polygamma (n , x )
577552
578553 def L_op (self , inputs , outputs , output_gradients ):
579554 (n , x ) = inputs
@@ -600,12 +575,8 @@ class Chi2SF(BinaryScalarOp):
600575
601576 nfunc_spec = ("scipy.stats.chi2.sf" , 2 , 1 )
602577
603- @staticmethod
604- def st_impl (x , k ):
605- return scipy .stats .chi2 .sf (x , k )
606-
607578 def impl (self , x , k ):
608- return Chi2SF . st_impl (x , k )
579+ return scipy . stats . chi2 . sf (x , k )
609580
610581 def c_support_code (self , ** kwargs ):
611582 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -643,12 +614,8 @@ class GammaInc(BinaryScalarOp):
643614
644615 nfunc_spec = ("scipy.special.gammainc" , 2 , 1 )
645616
646- @staticmethod
647- def st_impl (k , x ):
648- return scipy .special .gammainc (k , x )
649-
650617 def impl (self , k , x ):
651- return GammaInc . st_impl (k , x )
618+ return scipy . special . gammainc (k , x )
652619
653620 def grad (self , inputs , grads ):
654621 (k , x ) = inputs
@@ -694,12 +661,8 @@ class GammaIncC(BinaryScalarOp):
694661
695662 nfunc_spec = ("scipy.special.gammaincc" , 2 , 1 )
696663
697- @staticmethod
698- def st_impl (k , x ):
699- return scipy .special .gammaincc (k , x )
700-
701664 def impl (self , k , x ):
702- return GammaIncC . st_impl (k , x )
665+ return scipy . special . gammaincc (k , x )
703666
704667 def grad (self , inputs , grads ):
705668 (k , x ) = inputs
@@ -745,12 +708,8 @@ class GammaIncInv(BinaryScalarOp):
745708
746709 nfunc_spec = ("scipy.special.gammaincinv" , 2 , 1 )
747710
748- @staticmethod
749- def st_impl (k , x ):
750- return scipy .special .gammaincinv (k , x )
751-
752711 def impl (self , k , x ):
753- return GammaIncInv . st_impl (k , x )
712+ return scipy . special . gammaincinv (k , x )
754713
755714 def grad (self , inputs , grads ):
756715 (k , x ) = inputs
@@ -774,12 +733,8 @@ class GammaIncCInv(BinaryScalarOp):
774733
775734 nfunc_spec = ("scipy.special.gammainccinv" , 2 , 1 )
776735
777- @staticmethod
778- def st_impl (k , x ):
779- return scipy .special .gammainccinv (k , x )
780-
781736 def impl (self , k , x ):
782- return GammaIncCInv . st_impl (k , x )
737+ return scipy . special . gammainccinv (k , x )
783738
784739 def grad (self , inputs , grads ):
785740 (k , x ) = inputs
@@ -1013,12 +968,8 @@ class GammaU(BinaryScalarOp):
1013968
1014969 # Note there is no basic SciPy version so no nfunc_spec.
1015970
1016- @staticmethod
1017- def st_impl (k , x ):
1018- return scipy .special .gammaincc (k , x ) * scipy .special .gamma (k )
1019-
1020971 def impl (self , k , x ):
1021- return GammaU . st_impl (k , x )
972+ return scipy . special . gammaincc (k , x ) * scipy . special . gamma ( k )
1022973
1023974 def c_support_code (self , ** kwargs ):
1024975 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1049,12 +1000,8 @@ class GammaL(BinaryScalarOp):
10491000
10501001 # Note there is no basic SciPy version so no nfunc_spec.
10511002
1052- @staticmethod
1053- def st_impl (k , x ):
1054- return scipy .special .gammainc (k , x ) * scipy .special .gamma (k )
1055-
10561003 def impl (self , k , x ):
1057- return GammaL . st_impl (k , x )
1004+ return scipy . special . gammainc (k , x ) * scipy . special . gamma ( k )
10581005
10591006 def c_support_code (self , ** kwargs ):
10601007 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1085,12 +1032,8 @@ class Jv(BinaryScalarOp):
10851032
10861033 nfunc_spec = ("scipy.special.jv" , 2 , 1 )
10871034
1088- @staticmethod
1089- def st_impl (v , x ):
1090- return scipy .special .jv (v , x )
1091-
10921035 def impl (self , v , x ):
1093- return self . st_impl (v , x )
1036+ return scipy . special . jv (v , x )
10941037
10951038 def grad (self , inputs , grads ):
10961039 v , x = inputs
@@ -1114,12 +1057,8 @@ class J1(UnaryScalarOp):
11141057
11151058 nfunc_spec = ("scipy.special.j1" , 1 , 1 )
11161059
1117- @staticmethod
1118- def st_impl (x ):
1119- return scipy .special .j1 (x )
1120-
11211060 def impl (self , x ):
1122- return self . st_impl (x )
1061+ return scipy . special . j1 (x )
11231062
11241063 def grad (self , inputs , grads ):
11251064 (x ,) = inputs
@@ -1145,12 +1084,8 @@ class J0(UnaryScalarOp):
11451084
11461085 nfunc_spec = ("scipy.special.j0" , 1 , 1 )
11471086
1148- @staticmethod
1149- def st_impl (x ):
1150- return scipy .special .j0 (x )
1151-
11521087 def impl (self , x ):
1153- return self . st_impl (x )
1088+ return scipy . special . j0 (x )
11541089
11551090 def grad (self , inp , grads ):
11561091 (x ,) = inp
@@ -1176,12 +1111,8 @@ class Iv(BinaryScalarOp):
11761111
11771112 nfunc_spec = ("scipy.special.iv" , 2 , 1 )
11781113
1179- @staticmethod
1180- def st_impl (v , x ):
1181- return scipy .special .iv (v , x )
1182-
11831114 def impl (self , v , x ):
1184- return self . st_impl (v , x )
1115+ return scipy . special . iv (v , x )
11851116
11861117 def grad (self , inputs , grads ):
11871118 v , x = inputs
@@ -1205,12 +1136,8 @@ class I1(UnaryScalarOp):
12051136
12061137 nfunc_spec = ("scipy.special.i1" , 1 , 1 )
12071138
1208- @staticmethod
1209- def st_impl (x ):
1210- return scipy .special .i1 (x )
1211-
12121139 def impl (self , x ):
1213- return self . st_impl (x )
1140+ return scipy . special . i1 (x )
12141141
12151142 def grad (self , inputs , grads ):
12161143 (x ,) = inputs
@@ -1231,12 +1158,8 @@ class I0(UnaryScalarOp):
12311158
12321159 nfunc_spec = ("scipy.special.i0" , 1 , 1 )
12331160
1234- @staticmethod
1235- def st_impl (x ):
1236- return scipy .special .i0 (x )
1237-
12381161 def impl (self , x ):
1239- return self . st_impl (x )
1162+ return scipy . special . i0 (x )
12401163
12411164 def grad (self , inp , grads ):
12421165 (x ,) = inp
@@ -1257,12 +1180,8 @@ class Ive(BinaryScalarOp):
12571180
12581181 nfunc_spec = ("scipy.special.ive" , 2 , 1 )
12591182
1260- @staticmethod
1261- def st_impl (v , x ):
1262- return scipy .special .ive (v , x )
1263-
12641183 def impl (self , v , x ):
1265- return self . st_impl (v , x )
1184+ return scipy . special . ive (v , x )
12661185
12671186 def grad (self , inputs , grads ):
12681187 v , x = inputs
@@ -1286,12 +1205,8 @@ class Kve(BinaryScalarOp):
12861205
12871206 nfunc_spec = ("scipy.special.kve" , 2 , 1 )
12881207
1289- @staticmethod
1290- def st_impl (v , x ):
1291- return scipy .special .kve (v , x )
1292-
12931208 def impl (self , v , x ):
1294- return self . st_impl (v , x )
1209+ return scipy . special . kve (v , x )
12951210
12961211 def L_op (self , inputs , outputs , output_grads ):
12971212 v , x = inputs
@@ -1372,8 +1287,7 @@ class Softplus(UnaryScalarOp):
13721287 "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
13731288 """
13741289
1375- @staticmethod
1376- def static_impl (x ):
1290+ def impl (self , x ):
13771291 # If x is an int8 or uint8, numpy.exp will compute the result in
13781292 # half-precision (float16), where we want float32.
13791293 not_int8 = str (getattr (x , "dtype" , "" )) not in ("int8" , "uint8" )
@@ -1388,9 +1302,6 @@ def static_impl(x):
13881302 else :
13891303 return x
13901304
1391- def impl (self , x ):
1392- return Softplus .static_impl (x )
1393-
13941305 def grad (self , inp , grads ):
13951306 (x ,) = inp
13961307 (gz ,) = grads
@@ -1453,16 +1364,12 @@ class Log1mexp(UnaryScalarOp):
14531364 "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
14541365 """
14551366
1456- @staticmethod
1457- def static_impl (x ):
1367+ def impl (self , x ):
14581368 if x < np .log (0.5 ):
14591369 return np .log1p (- np .exp (x ))
14601370 else :
14611371 return np .log (- np .expm1 (x ))
14621372
1463- def impl (self , x ):
1464- return Log1mexp .static_impl (x )
1465-
14661373 def grad (self , inp , grads ):
14671374 (x ,) = inp
14681375 (gz ,) = grads
@@ -1794,12 +1701,8 @@ class Hyp2F1(ScalarOp):
17941701 nin = 4
17951702 nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
17961703
1797- @staticmethod
1798- def st_impl (a , b , c , z ):
1799- return scipy .special .hyp2f1 (a , b , c , z )
1800-
18011704 def impl (self , a , b , c , z ):
1802- return Hyp2F1 . st_impl (a , b , c , z )
1705+ return scipy . special . hyp2f1 (a , b , c , z )
18031706
18041707 def grad (self , inputs , grads ):
18051708 a , b , c , z = inputs
0 commit comments