Skip to content

Commit 289067c

Browse files
committed
More direct access to special functions
1 parent 5010892 commit 289067c

File tree

1 file changed

+33
-33
lines changed

1 file changed

+33
-33
lines changed

pytensor/scalar/math.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from textwrap import dedent
1010

1111
import numpy as np
12-
import scipy.special
12+
from scipy import special
1313

1414
from pytensor.configdefaults import config
1515
from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -53,7 +53,7 @@ class Erf(UnaryScalarOp):
5353
nfunc_spec = ("scipy.special.erf", 1, 1)
5454

5555
def impl(self, x):
56-
return scipy.special.erf(x)
56+
return special.erf(x)
5757

5858
def L_op(self, inputs, outputs, grads):
5959
(x,) = inputs
@@ -87,7 +87,7 @@ class Erfc(UnaryScalarOp):
8787
nfunc_spec = ("scipy.special.erfc", 1, 1)
8888

8989
def impl(self, x):
90-
return scipy.special.erfc(x)
90+
return special.erfc(x)
9191

9292
def L_op(self, inputs, outputs, grads):
9393
(x,) = inputs
@@ -114,7 +114,7 @@ def c_code(self, node, name, inp, out, sub):
114114
return f"{z} = erfc(({cast}){x});"
115115

116116

117-
# scipy.special.erfc don't support complex. Why?
117+
# special.erfc don't support complex. Why?
118118
erfc = Erfc(upgrade_to_float_no_complex, name="erfc")
119119

120120

@@ -136,7 +136,7 @@ class Erfcx(UnaryScalarOp):
136136
nfunc_spec = ("scipy.special.erfcx", 1, 1)
137137

138138
def impl(self, x):
139-
return scipy.special.erfcx(x)
139+
return special.erfcx(x)
140140

141141
def L_op(self, inputs, outputs, grads):
142142
(x,) = inputs
@@ -192,7 +192,7 @@ class Erfinv(UnaryScalarOp):
192192
nfunc_spec = ("scipy.special.erfinv", 1, 1)
193193

194194
def impl(self, x):
195-
return scipy.special.erfinv(x)
195+
return special.erfinv(x)
196196

197197
def L_op(self, inputs, outputs, grads):
198198
(x,) = inputs
@@ -227,7 +227,7 @@ class Erfcinv(UnaryScalarOp):
227227
nfunc_spec = ("scipy.special.erfcinv", 1, 1)
228228

229229
def impl(self, x):
230-
return scipy.special.erfcinv(x)
230+
return special.erfcinv(x)
231231

232232
def L_op(self, inputs, outputs, grads):
233233
(x,) = inputs
@@ -262,7 +262,7 @@ class Owens_t(BinaryScalarOp):
262262
nfunc_spec = ("scipy.special.owens_t", 2, 1)
263263

264264
def impl(self, h, a):
265-
return scipy.special.owens_t(h, a)
265+
return special.owens_t(h, a)
266266

267267
def grad(self, inputs, grads):
268268
(h, a) = inputs
@@ -287,7 +287,7 @@ class Gamma(UnaryScalarOp):
287287
nfunc_spec = ("scipy.special.gamma", 1, 1)
288288

289289
def impl(self, x):
290-
return scipy.special.gamma(x)
290+
return special.gamma(x)
291291

292292
def L_op(self, inputs, outputs, gout):
293293
(x,) = inputs
@@ -322,7 +322,7 @@ class GammaLn(UnaryScalarOp):
322322
nfunc_spec = ("scipy.special.gammaln", 1, 1)
323323

324324
def impl(self, x):
325-
return scipy.special.gammaln(x)
325+
return special.gammaln(x)
326326

327327
def L_op(self, inputs, outputs, grads):
328328
(x,) = inputs
@@ -362,7 +362,7 @@ class Psi(UnaryScalarOp):
362362
nfunc_spec = ("scipy.special.psi", 1, 1)
363363

364364
def impl(self, x):
365-
return scipy.special.psi(x)
365+
return special.psi(x)
366366

367367
def L_op(self, inputs, outputs, grads):
368368
(x,) = inputs
@@ -449,7 +449,7 @@ class TriGamma(UnaryScalarOp):
449449
"""
450450

451451
def impl(self, x):
452-
return scipy.special.polygamma(1, x)
452+
return special.polygamma(1, x)
453453

454454
def L_op(self, inputs, outputs, outputs_gradients):
455455
(x,) = inputs
@@ -548,7 +548,7 @@ def output_types_preference(n_type, x_type):
548548
return upgrade_to_float_no_complex(x_type)
549549

550550
def impl(self, n, x):
551-
return scipy.special.polygamma(n, x)
551+
return special.polygamma(n, x)
552552

553553
def L_op(self, inputs, outputs, output_gradients):
554554
(n, x) = inputs
@@ -573,10 +573,10 @@ class Chi2SF(BinaryScalarOp):
573573
ie. chi2 pvalue (chi2 'survival function')
574574
"""
575575

576-
nfunc_spec = ("scipy.stats.chi2.sf", 2, 1)
576+
nfunc_spec = ("scipy.special.chdtrc", 2, 1)
577577

578578
def impl(self, x, k):
579-
return scipy.stats.chi2.sf(x, k)
579+
return special.chdtrc(x, k)
580580

581581
def c_support_code(self, **kwargs):
582582
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -615,7 +615,7 @@ class GammaInc(BinaryScalarOp):
615615
nfunc_spec = ("scipy.special.gammainc", 2, 1)
616616

617617
def impl(self, k, x):
618-
return scipy.special.gammainc(k, x)
618+
return special.gammainc(k, x)
619619

620620
def grad(self, inputs, grads):
621621
(k, x) = inputs
@@ -662,7 +662,7 @@ class GammaIncC(BinaryScalarOp):
662662
nfunc_spec = ("scipy.special.gammaincc", 2, 1)
663663

664664
def impl(self, k, x):
665-
return scipy.special.gammaincc(k, x)
665+
return special.gammaincc(k, x)
666666

667667
def grad(self, inputs, grads):
668668
(k, x) = inputs
@@ -709,7 +709,7 @@ class GammaIncInv(BinaryScalarOp):
709709
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
710710

711711
def impl(self, k, x):
712-
return scipy.special.gammaincinv(k, x)
712+
return special.gammaincinv(k, x)
713713

714714
def grad(self, inputs, grads):
715715
(k, x) = inputs
@@ -734,7 +734,7 @@ class GammaIncCInv(BinaryScalarOp):
734734
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
735735

736736
def impl(self, k, x):
737-
return scipy.special.gammainccinv(k, x)
737+
return special.gammainccinv(k, x)
738738

739739
def grad(self, inputs, grads):
740740
(k, x) = inputs
@@ -969,7 +969,7 @@ class GammaU(BinaryScalarOp):
969969
# Note there is no basic SciPy version so no nfunc_spec.
970970

971971
def impl(self, k, x):
972-
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
972+
return special.gammaincc(k, x) * special.gamma(k)
973973

974974
def c_support_code(self, **kwargs):
975975
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1001,7 +1001,7 @@ class GammaL(BinaryScalarOp):
10011001
# Note there is no basic SciPy version so no nfunc_spec.
10021002

10031003
def impl(self, k, x):
1004-
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
1004+
return special.gammainc(k, x) * special.gamma(k)
10051005

10061006
def c_support_code(self, **kwargs):
10071007
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1033,7 +1033,7 @@ class Jv(BinaryScalarOp):
10331033
nfunc_spec = ("scipy.special.jv", 2, 1)
10341034

10351035
def impl(self, v, x):
1036-
return scipy.special.jv(v, x)
1036+
return special.jv(v, x)
10371037

10381038
def grad(self, inputs, grads):
10391039
v, x = inputs
@@ -1058,7 +1058,7 @@ class J1(UnaryScalarOp):
10581058
nfunc_spec = ("scipy.special.j1", 1, 1)
10591059

10601060
def impl(self, x):
1061-
return scipy.special.j1(x)
1061+
return special.j1(x)
10621062

10631063
def grad(self, inputs, grads):
10641064
(x,) = inputs
@@ -1085,7 +1085,7 @@ class J0(UnaryScalarOp):
10851085
nfunc_spec = ("scipy.special.j0", 1, 1)
10861086

10871087
def impl(self, x):
1088-
return scipy.special.j0(x)
1088+
return special.j0(x)
10891089

10901090
def grad(self, inp, grads):
10911091
(x,) = inp
@@ -1112,7 +1112,7 @@ class Iv(BinaryScalarOp):
11121112
nfunc_spec = ("scipy.special.iv", 2, 1)
11131113

11141114
def impl(self, v, x):
1115-
return scipy.special.iv(v, x)
1115+
return special.iv(v, x)
11161116

11171117
def grad(self, inputs, grads):
11181118
v, x = inputs
@@ -1137,7 +1137,7 @@ class I1(UnaryScalarOp):
11371137
nfunc_spec = ("scipy.special.i1", 1, 1)
11381138

11391139
def impl(self, x):
1140-
return scipy.special.i1(x)
1140+
return special.i1(x)
11411141

11421142
def grad(self, inputs, grads):
11431143
(x,) = inputs
@@ -1159,7 +1159,7 @@ class I0(UnaryScalarOp):
11591159
nfunc_spec = ("scipy.special.i0", 1, 1)
11601160

11611161
def impl(self, x):
1162-
return scipy.special.i0(x)
1162+
return special.i0(x)
11631163

11641164
def grad(self, inp, grads):
11651165
(x,) = inp
@@ -1181,7 +1181,7 @@ class Ive(BinaryScalarOp):
11811181
nfunc_spec = ("scipy.special.ive", 2, 1)
11821182

11831183
def impl(self, v, x):
1184-
return scipy.special.ive(v, x)
1184+
return special.ive(v, x)
11851185

11861186
def grad(self, inputs, grads):
11871187
v, x = inputs
@@ -1206,7 +1206,7 @@ class Kve(BinaryScalarOp):
12061206
nfunc_spec = ("scipy.special.kve", 2, 1)
12071207

12081208
def impl(self, v, x):
1209-
return scipy.special.kve(v, x)
1209+
return special.kve(v, x)
12101210

12111211
def L_op(self, inputs, outputs, output_grads):
12121212
v, x = inputs
@@ -1236,7 +1236,7 @@ class Sigmoid(UnaryScalarOp):
12361236
nfunc_spec = ("scipy.special.expit", 1, 1)
12371237

12381238
def impl(self, x):
1239-
return scipy.special.expit(x)
1239+
return special.expit(x)
12401240

12411241
def grad(self, inp, grads):
12421242
(x,) = inp
@@ -1403,7 +1403,7 @@ class BetaInc(ScalarOp):
14031403
nfunc_spec = ("scipy.special.betainc", 3, 1)
14041404

14051405
def impl(self, a, b, x):
1406-
return scipy.special.betainc(a, b, x)
1406+
return special.betainc(a, b, x)
14071407

14081408
def grad(self, inp, grads):
14091409
a, b, x = inp
@@ -1663,7 +1663,7 @@ class BetaIncInv(ScalarOp):
16631663
nfunc_spec = ("scipy.special.betaincinv", 3, 1)
16641664

16651665
def impl(self, a, b, x):
1666-
return scipy.special.betaincinv(a, b, x)
1666+
return special.betaincinv(a, b, x)
16671667

16681668
def grad(self, inputs, grads):
16691669
(a, b, x) = inputs
@@ -1702,7 +1702,7 @@ class Hyp2F1(ScalarOp):
17021702
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
17031703

17041704
def impl(self, a, b, c, z):
1705-
return scipy.special.hyp2f1(a, b, c, z)
1705+
return special.hyp2f1(a, b, c, z)
17061706

17071707
def grad(self, inputs, grads):
17081708
a, b, c, z = inputs

0 commit comments

Comments
 (0)