Skip to content

Commit 4518abf

Browse files
committed
Remove unused ScalarOp.st_impl
1 parent a0fe30d commit 4518abf

File tree

1 file changed

+24
-121
lines changed

1 file changed

+24
-121
lines changed

pytensor/scalar/math.py

Lines changed: 24 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
import scipy.special
13-
import scipy.stats
1413

1514
from pytensor.configdefaults import config
1615
from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -262,12 +261,8 @@ def c_code(self, node, name, inp, out, sub):
262261
class Owens_t(BinaryScalarOp):
263262
nfunc_spec = ("scipy.special.owens_t", 2, 1)
264263

265-
@staticmethod
266-
def st_impl(h, a):
267-
return scipy.special.owens_t(h, a)
268-
269264
def impl(self, h, a):
270-
return Owens_t.st_impl(h, a)
265+
return scipy.special.owens_t(h, a)
271266

272267
def grad(self, inputs, grads):
273268
(h, a) = inputs
@@ -291,12 +286,8 @@ def c_code(self, *args, **kwargs):
291286
class Gamma(UnaryScalarOp):
292287
nfunc_spec = ("scipy.special.gamma", 1, 1)
293288

294-
@staticmethod
295-
def st_impl(x):
296-
return scipy.special.gamma(x)
297-
298289
def impl(self, x):
299-
return Gamma.st_impl(x)
290+
return scipy.special.gamma(x)
300291

301292
def L_op(self, inputs, outputs, gout):
302293
(x,) = inputs
@@ -330,12 +321,8 @@ class GammaLn(UnaryScalarOp):
330321

331322
nfunc_spec = ("scipy.special.gammaln", 1, 1)
332323

333-
@staticmethod
334-
def st_impl(x):
335-
return scipy.special.gammaln(x)
336-
337324
def impl(self, x):
338-
return GammaLn.st_impl(x)
325+
return scipy.special.gammaln(x)
339326

340327
def L_op(self, inputs, outputs, grads):
341328
(x,) = inputs
@@ -374,12 +361,8 @@ class Psi(UnaryScalarOp):
374361

375362
nfunc_spec = ("scipy.special.psi", 1, 1)
376363

377-
@staticmethod
378-
def st_impl(x):
379-
return scipy.special.psi(x)
380-
381364
def impl(self, x):
382-
return Psi.st_impl(x)
365+
return scipy.special.psi(x)
383366

384367
def L_op(self, inputs, outputs, grads):
385368
(x,) = inputs
@@ -465,12 +448,8 @@ class TriGamma(UnaryScalarOp):
465448
466449
"""
467450

468-
@staticmethod
469-
def st_impl(x):
470-
return scipy.special.polygamma(1, x)
471-
472451
def impl(self, x):
473-
return TriGamma.st_impl(x)
452+
return scipy.special.polygamma(1, x)
474453

475454
def L_op(self, inputs, outputs, outputs_gradients):
476455
(x,) = inputs
@@ -568,12 +547,8 @@ def output_types_preference(n_type, x_type):
568547
# Scipy doesn't support it
569548
return upgrade_to_float_no_complex(x_type)
570549

571-
@staticmethod
572-
def st_impl(n, x):
573-
return scipy.special.polygamma(n, x)
574-
575550
def impl(self, n, x):
576-
return PolyGamma.st_impl(n, x)
551+
return scipy.special.polygamma(n, x)
577552

578553
def L_op(self, inputs, outputs, output_gradients):
579554
(n, x) = inputs
@@ -600,12 +575,8 @@ class Chi2SF(BinaryScalarOp):
600575

601576
nfunc_spec = ("scipy.stats.chi2.sf", 2, 1)
602577

603-
@staticmethod
604-
def st_impl(x, k):
605-
return scipy.stats.chi2.sf(x, k)
606-
607578
def impl(self, x, k):
608-
return Chi2SF.st_impl(x, k)
579+
return scipy.stats.chi2.sf(x, k)
609580

610581
def c_support_code(self, **kwargs):
611582
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -643,12 +614,8 @@ class GammaInc(BinaryScalarOp):
643614

644615
nfunc_spec = ("scipy.special.gammainc", 2, 1)
645616

646-
@staticmethod
647-
def st_impl(k, x):
648-
return scipy.special.gammainc(k, x)
649-
650617
def impl(self, k, x):
651-
return GammaInc.st_impl(k, x)
618+
return scipy.special.gammainc(k, x)
652619

653620
def grad(self, inputs, grads):
654621
(k, x) = inputs
@@ -694,12 +661,8 @@ class GammaIncC(BinaryScalarOp):
694661

695662
nfunc_spec = ("scipy.special.gammaincc", 2, 1)
696663

697-
@staticmethod
698-
def st_impl(k, x):
699-
return scipy.special.gammaincc(k, x)
700-
701664
def impl(self, k, x):
702-
return GammaIncC.st_impl(k, x)
665+
return scipy.special.gammaincc(k, x)
703666

704667
def grad(self, inputs, grads):
705668
(k, x) = inputs
@@ -745,12 +708,8 @@ class GammaIncInv(BinaryScalarOp):
745708

746709
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
747710

748-
@staticmethod
749-
def st_impl(k, x):
750-
return scipy.special.gammaincinv(k, x)
751-
752711
def impl(self, k, x):
753-
return GammaIncInv.st_impl(k, x)
712+
return scipy.special.gammaincinv(k, x)
754713

755714
def grad(self, inputs, grads):
756715
(k, x) = inputs
@@ -774,12 +733,8 @@ class GammaIncCInv(BinaryScalarOp):
774733

775734
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
776735

777-
@staticmethod
778-
def st_impl(k, x):
779-
return scipy.special.gammainccinv(k, x)
780-
781736
def impl(self, k, x):
782-
return GammaIncCInv.st_impl(k, x)
737+
return scipy.special.gammainccinv(k, x)
783738

784739
def grad(self, inputs, grads):
785740
(k, x) = inputs
@@ -1013,12 +968,8 @@ class GammaU(BinaryScalarOp):
1013968

1014969
# Note there is no basic SciPy version so no nfunc_spec.
1015970

1016-
@staticmethod
1017-
def st_impl(k, x):
1018-
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
1019-
1020971
def impl(self, k, x):
1021-
return GammaU.st_impl(k, x)
972+
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
1022973

1023974
def c_support_code(self, **kwargs):
1024975
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1049,12 +1000,8 @@ class GammaL(BinaryScalarOp):
10491000

10501001
# Note there is no basic SciPy version so no nfunc_spec.
10511002

1052-
@staticmethod
1053-
def st_impl(k, x):
1054-
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
1055-
10561003
def impl(self, k, x):
1057-
return GammaL.st_impl(k, x)
1004+
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
10581005

10591006
def c_support_code(self, **kwargs):
10601007
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1085,12 +1032,8 @@ class Jv(BinaryScalarOp):
10851032

10861033
nfunc_spec = ("scipy.special.jv", 2, 1)
10871034

1088-
@staticmethod
1089-
def st_impl(v, x):
1090-
return scipy.special.jv(v, x)
1091-
10921035
def impl(self, v, x):
1093-
return self.st_impl(v, x)
1036+
return scipy.special.jv(v, x)
10941037

10951038
def grad(self, inputs, grads):
10961039
v, x = inputs
@@ -1114,12 +1057,8 @@ class J1(UnaryScalarOp):
11141057

11151058
nfunc_spec = ("scipy.special.j1", 1, 1)
11161059

1117-
@staticmethod
1118-
def st_impl(x):
1119-
return scipy.special.j1(x)
1120-
11211060
def impl(self, x):
1122-
return self.st_impl(x)
1061+
return scipy.special.j1(x)
11231062

11241063
def grad(self, inputs, grads):
11251064
(x,) = inputs
@@ -1145,12 +1084,8 @@ class J0(UnaryScalarOp):
11451084

11461085
nfunc_spec = ("scipy.special.j0", 1, 1)
11471086

1148-
@staticmethod
1149-
def st_impl(x):
1150-
return scipy.special.j0(x)
1151-
11521087
def impl(self, x):
1153-
return self.st_impl(x)
1088+
return scipy.special.j0(x)
11541089

11551090
def grad(self, inp, grads):
11561091
(x,) = inp
@@ -1176,12 +1111,8 @@ class Iv(BinaryScalarOp):
11761111

11771112
nfunc_spec = ("scipy.special.iv", 2, 1)
11781113

1179-
@staticmethod
1180-
def st_impl(v, x):
1181-
return scipy.special.iv(v, x)
1182-
11831114
def impl(self, v, x):
1184-
return self.st_impl(v, x)
1115+
return scipy.special.iv(v, x)
11851116

11861117
def grad(self, inputs, grads):
11871118
v, x = inputs
@@ -1205,12 +1136,8 @@ class I1(UnaryScalarOp):
12051136

12061137
nfunc_spec = ("scipy.special.i1", 1, 1)
12071138

1208-
@staticmethod
1209-
def st_impl(x):
1210-
return scipy.special.i1(x)
1211-
12121139
def impl(self, x):
1213-
return self.st_impl(x)
1140+
return scipy.special.i1(x)
12141141

12151142
def grad(self, inputs, grads):
12161143
(x,) = inputs
@@ -1231,12 +1158,8 @@ class I0(UnaryScalarOp):
12311158

12321159
nfunc_spec = ("scipy.special.i0", 1, 1)
12331160

1234-
@staticmethod
1235-
def st_impl(x):
1236-
return scipy.special.i0(x)
1237-
12381161
def impl(self, x):
1239-
return self.st_impl(x)
1162+
return scipy.special.i0(x)
12401163

12411164
def grad(self, inp, grads):
12421165
(x,) = inp
@@ -1257,12 +1180,8 @@ class Ive(BinaryScalarOp):
12571180

12581181
nfunc_spec = ("scipy.special.ive", 2, 1)
12591182

1260-
@staticmethod
1261-
def st_impl(v, x):
1262-
return scipy.special.ive(v, x)
1263-
12641183
def impl(self, v, x):
1265-
return self.st_impl(v, x)
1184+
return scipy.special.ive(v, x)
12661185

12671186
def grad(self, inputs, grads):
12681187
v, x = inputs
@@ -1286,12 +1205,8 @@ class Kve(BinaryScalarOp):
12861205

12871206
nfunc_spec = ("scipy.special.kve", 2, 1)
12881207

1289-
@staticmethod
1290-
def st_impl(v, x):
1291-
return scipy.special.kve(v, x)
1292-
12931208
def impl(self, v, x):
1294-
return self.st_impl(v, x)
1209+
return scipy.special.kve(v, x)
12951210

12961211
def L_op(self, inputs, outputs, output_grads):
12971212
v, x = inputs
@@ -1372,8 +1287,7 @@ class Softplus(UnaryScalarOp):
13721287
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
13731288
"""
13741289

1375-
@staticmethod
1376-
def static_impl(x):
1290+
def impl(self, x):
13771291
# If x is an int8 or uint8, numpy.exp will compute the result in
13781292
# half-precision (float16), where we want float32.
13791293
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
@@ -1388,9 +1302,6 @@ def static_impl(x):
13881302
else:
13891303
return x
13901304

1391-
def impl(self, x):
1392-
return Softplus.static_impl(x)
1393-
13941305
def grad(self, inp, grads):
13951306
(x,) = inp
13961307
(gz,) = grads
@@ -1453,16 +1364,12 @@ class Log1mexp(UnaryScalarOp):
14531364
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
14541365
"""
14551366

1456-
@staticmethod
1457-
def static_impl(x):
1367+
def impl(self, x):
14581368
if x < np.log(0.5):
14591369
return np.log1p(-np.exp(x))
14601370
else:
14611371
return np.log(-np.expm1(x))
14621372

1463-
def impl(self, x):
1464-
return Log1mexp.static_impl(x)
1465-
14661373
def grad(self, inp, grads):
14671374
(x,) = inp
14681375
(gz,) = grads
@@ -1794,12 +1701,8 @@ class Hyp2F1(ScalarOp):
17941701
nin = 4
17951702
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
17961703

1797-
@staticmethod
1798-
def st_impl(a, b, c, z):
1799-
return scipy.special.hyp2f1(a, b, c, z)
1800-
18011704
def impl(self, a, b, c, z):
1802-
return Hyp2F1.st_impl(a, b, c, z)
1705+
return scipy.special.hyp2f1(a, b, c, z)
18031706

18041707
def grad(self, inputs, grads):
18051708
a, b, c, z = inputs

0 commit comments

Comments
 (0)