Skip to content

Commit b4a3e16

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Delete si_vjp from JAX. Dead weight at this point. It is now replaced with jax.vjp.
PiperOrigin-RevId: 844865984
1 parent a8983f8 commit b4a3e16

File tree

4 files changed

+54
-170
lines changed

4 files changed

+54
-170
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
3737
For the moment, during Python type checking, we continue to declare `Tracer`
3838
as a subclass of `Array`, however we expect to remove this in a future
3939
release.
40+
* `jax.experimental.si_vjp` has been deleted.
41+
`jax.vjp` subsumes it's functionality.
4042

4143
## JAX 0.8.1 (November 18, 2025)
4244

jax/_src/api.py

Lines changed: 26 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2212,106 +2212,12 @@ def vjp(
22122212
fun, debug_info=debug_info("vjp", fun, primals, {}))
22132213
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
22142214

2215-
@partial(api_boundary, repro_api_name="jax.experimental.saved_input_vjp")
2216-
def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
2217-
allow_unused: bool = True, allow_opaque: bool = True):
2218-
if len(which) != len(primals):
2219-
raise ValueError(
2220-
"length of 'which' argument must equal the number of primal input values, "
2221-
f"but got {len(which)=} and {len(primals)=}")
2222-
2223-
dbg = debug_info("saved_input_vjp", f, primals, {})
2224-
fun = lu.wrap_init(f, debug_info=dbg)
2225-
primals_flat, in_tree = tree_flatten(primals)
2226-
fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
2227-
out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(fun, *primals_flat)
2228-
out_known = [pval.is_known() for pval in out_pvals]
2229-
primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w))
2230-
id_map = {id(x): i for i, x in enumerate(primals_filt)}
2231-
opaque_residuals = []
2232-
res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else
2233-
RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore
2234-
for r in residuals]
2235-
out_primal_avals = map(shaped_abstractify, out_primals_flat)
2236-
f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree,
2237-
out_tree(), out_known, jaxpr, out_primal_avals),
2238-
opaque_residuals)
2239-
2240-
if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}):
2241-
unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which))
2242-
if w and id(x) not in res_ids]
2243-
assert unused
2244-
if len(unused) == 1:
2245-
(i, a), = unused
2246-
start, was = "an input value", "was"
2247-
msg = f" {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}"
2248-
else:
2249-
start, was = "multiple input values", "were"
2250-
msg = "\n" + "\n".join(f" * {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}"
2251-
for i, a in unused)
2252-
raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} "
2253-
f"not used by the backward pass:{msg}")
2254-
2255-
if not allow_opaque and opaque_residuals:
2256-
msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals)
2257-
raise Exception(f"with {allow_opaque=}, the backward pass requires opaque "
2258-
f"(non-input) residuals: {msg}")
2259-
2260-
out_primals = tree_unflatten(out_tree(), out_primals_flat)
2261-
return out_primals, f_vjp
2262-
2263-
def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known,
2264-
jaxpr, out_primal_avals, opaque_residuals, ct,
2265-
*saved_primals):
2266-
primals_filtered, filtered_tree_ = tree_flatten(saved_primals)
2267-
if filtered_tree != filtered_tree_:
2268-
raise ValueError(
2269-
"inputs passed to f_vjp must be a tuple of (pytrees of) "
2270-
"arrays with the same structure as\n"
2271-
" tuple(x for x, w in zip(inputs, which) if w)\n"
2272-
"given the original call\n"
2273-
" _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n"
2274-
"but the structures differ:\n" +
2275-
"\n".join(f" * inputs{keystr(path)} was a {thing1} in the original "
2276-
f"call, but a {thing2} here, so {explanation}"
2277-
for path, thing1, thing2, explanation
2278-
in equality_errors_pytreedef(filtered_tree, filtered_tree_)))
2279-
2280-
residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx]
2281-
for i in res_spec]
2282-
dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars]
2283-
cts_flat, out_tree_ = tree_flatten(ct)
2284-
if out_tree_ != out_tree:
2285-
raise ValueError(f"unexpected tree structure of argument to vjp function: "
2286-
f"got {out_tree_}, but expected to match {out_tree}")
2287-
for arg, aval in zip(cts_flat, out_primal_avals):
2288-
ct_aval = shaped_abstractify(arg)
2289-
ct_aval_expected = aval.to_cotangent_aval()
2290-
if (not core.typecompat(ct_aval, ct_aval_expected) and
2291-
not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
2292-
raise ValueError(
2293-
"unexpected JAX type (e.g. shape/dtype) for argument to vjp function: "
2294-
f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} "
2295-
f"because the corresponding output of the function had JAX type "
2296-
f"{aval.str_short()}")
2297-
2298-
cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
2299-
arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat)
2300-
return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts))
2301-
2302-
@dataclasses.dataclass(frozen=True)
2303-
class RSpec:
2304-
idx: int
2305-
primal: bool
2306-
2307-
si_vjp = saved_input_vjp
2308-
2309-
23102215
def _vjp(fun, *primals, has_aux=False):
23112216
canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x)
23122217
primals = tree_map(canon, primals)
23132218
primals_flat, in_tree = tree_flatten(primals)
2314-
for arg in primals_flat: dispatch.check_arg(arg)
2219+
for arg in primals_flat:
2220+
dispatch.check_arg(arg)
23152221
if not has_aux:
23162222
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
23172223
out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(
@@ -2340,22 +2246,14 @@ def _vjp(fun, *primals, has_aux=False):
23402246
else:
23412247
return out_primals, f_vjp, tree_unflatten(aux_tree, aux)
23422248

2343-
def tuptree_map(f, treedef, x):
2344-
return treedef.walk(lambda xs, _: tuple(xs), f, x)
2345-
2346-
2347-
def _is_ref(x):
2348-
from jax._src.state.types import AbstractRef
2349-
try: return isinstance(typeof(x), AbstractRef)
2350-
except: return False
2351-
23522249
def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
23532250
args_res, opaque_res, *maybe_ct_refs):
23542251
if not maybe_ct_refs:
23552252
maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves
23562253
else:
23572254
maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
2358-
if in_tree != in_tree_: raise Exception # TODO accept isomorph tuple tree
2255+
if in_tree != in_tree_:
2256+
raise Exception # TODO accept isomorph tuple tree
23592257
args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
23602258
residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
23612259
maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval)
@@ -2366,7 +2264,8 @@ def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
23662264
def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
23672265
maybe_refs, out_ct):
23682266
cts_flat, out_tree_ = tree_flatten(out_ct)
2369-
if out_tree != out_tree_: _vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
2267+
if out_tree != out_tree_:
2268+
_vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
23702269
_vjp_check_ct_avals(cts_flat, out_primal_avals)
23712270
cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
23722271
ad.backward_pass3(jaxpr, True, residuals, maybe_refs, cts_flat)
@@ -2375,6 +2274,23 @@ def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
23752274
arg_cts = map(ad.instantiate_zeros, arg_cts)
23762275
return tree_unflatten(in_tree, arg_cts)
23772276

2277+
2278+
@dataclasses.dataclass(frozen=True)
2279+
class RSpec:
2280+
idx: int
2281+
primal: bool
2282+
2283+
def tuptree_map(f, treedef, x):
2284+
return treedef.walk(lambda xs, _: tuple(xs), f, x)
2285+
2286+
def _is_ref(x):
2287+
from jax._src.state.types import AbstractRef
2288+
try:
2289+
return isinstance(typeof(x), AbstractRef)
2290+
except:
2291+
return False
2292+
2293+
23782294
_vjp_too_many_args = """
23792295
The function returned by `jax.vjp` applied to {} was called with {} arguments,
23802296
but functions returned by `jax.vjp` must be called with a single argument
@@ -2396,6 +2312,7 @@ def f(x):
23962312
arguments rather than in a tuple, this error can arise.
23972313
""".format
23982314

2315+
23992316
def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
24002317
msg = f"""unexpected tree structure.
24012318
@@ -2410,6 +2327,7 @@ def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
24102327
in equality_errors_pytreedef(out_tree, ct_tree))
24112328
raise ValueError(msg)
24122329

2330+
24132331
def _vjp_check_ct_avals(cts, primal_avals):
24142332
# TODO(mattjj): improve this error by flattening with keys in the first place
24152333
for ct, aval in zip(cts, primal_avals):
@@ -2425,6 +2343,7 @@ def _vjp_check_ct_avals(cts, primal_avals):
24252343
"because the corresponding output of the differentiated function had JAX type "
24262344
f"{aval.str_short()}")
24272345

2346+
24282347
@register_dataclass
24292348
@dataclasses.dataclass(frozen=True)
24302349
class NotNeeded:

jax/experimental/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@
2121
# experimental features and as a result, more flexibility to manage their status
2222
# and lifetimes.
2323

24-
from jax._src.api import (
25-
saved_input_vjp as saved_input_vjp,
26-
si_vjp as si_vjp
27-
)
2824
from jax._src.callback import (
2925
io_callback as io_callback
3026
)

tests/api_test.py

Lines changed: 26 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7769,9 +7769,12 @@ def test_basic(self):
77697769
def f(x, y):
77707770
return x * y
77717771

7772-
primals = 2., 3.
7773-
y, f_vjp = api.si_vjp(f, [True, True], *primals)
7774-
arg_cts = f_vjp(1., *primals)
7772+
primals = [2., 3.]
7773+
y, f_vjp = jax.vjp(f, *primals)
7774+
f_vjp.args_res = [None, None]
7775+
y_grad = 1.
7776+
f_vjp.args_res = primals
7777+
arg_cts = f_vjp(1.)
77757778
self.assertAllClose(y, 6.)
77767779
self.assertAllClose(arg_cts, (3., 2.))
77777780

@@ -7782,29 +7785,20 @@ def f(x, y):
77827785
@jax.jit
77837786
def g():
77847787
primals = 2., 3.
7785-
y, f_vjp = api.si_vjp(f, [True, True], *primals)
7788+
y, f_vjp = jax.vjp(f, *primals)
7789+
f_vjp.args_res = [None, None]
77867790
return y, f_vjp
77877791

77887792
@jax.jit
77897793
def h(f_vjp):
7790-
return f_vjp(1., 2., 3.)
7794+
f_vjp.args_res = [2., 3.]
7795+
return f_vjp(1.)
77917796

77927797
y, f_vjp = g()
77937798
arg_cts = h(f_vjp)
77947799
self.assertAllClose(y, 6.)
77957800
self.assertAllClose(arg_cts, (3., 2.))
77967801

7797-
def test_basic_unused(self):
7798-
f = jnp.sin
7799-
primals = 3.,
7800-
y, f_vjp = api.si_vjp(f, [True], *primals)
7801-
x_ct, = f_vjp(1., *primals)
7802-
self.assertAllClose(y, jnp.sin(3.))
7803-
self.assertAllClose(x_ct, jnp.cos(3.))
7804-
7805-
with self.assertRaisesRegex(Exception, "not used by the backward pass: x"):
7806-
_ = api.si_vjp(f, [True], *primals, allow_unused=False)
7807-
78087802
def test_basic_unused_vjp3(self):
78097803
f = jnp.sin
78107804
primals = 3.,
@@ -7814,58 +7808,28 @@ def test_basic_unused_vjp3(self):
78147808
self.assertAllClose(x_ct, jnp.cos(3.))
78157809
self.assertIsInstance(f_vjp.args_res[0], api.NotNeeded) # can check if unused
78167810

7817-
def test_basic_opaque(self):
7818-
f = jnp.sin
7819-
primals = 3.,
7820-
with self.assertRaisesRegex(Exception, "the backward pass requires opaque"):
7821-
_ = api.si_vjp(f, [True], *primals, allow_opaque=False)
7822-
78237811
def test_basic_opaque_vjp3(self):
78247812
f = jnp.sin
78257813
primals = 3.,
78267814
_, f_vjp = api.vjp(f, *primals)
7827-
assert f_vjp.opaque_residuals # can detect if opaque res are used
7815+
self.assertTrue(f_vjp.opaque_residuals) # can detect if opaque res are used
78287816

78297817
def test_basic_pytree_error(self):
78307818
def f(x):
78317819
return [x['hi'] * x['bye']]
78327820

7833-
y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.})
7834-
arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
7821+
y, f_vjp = jax.vjp(f, {'hi': 2., 'bye': 3.})
7822+
f_vjp.args_res = [None]
7823+
y_grad = [1.]
7824+
f_vjp.args_res = [{'hi': 2., 'bye': 3.}]
7825+
arg_ct, = f_vjp(y_grad)
78357826
self.assertAllClose(y, [6.])
78367827
self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
78377828

7838-
with self.assertRaisesRegex(ValueError, "but the structures differ"):
7839-
f_vjp(1., {'hi': 2.})
7840-
7841-
# TODO(mattjj): improve this vjp3 error message
7842-
# def test_basic_pytree_error_vjp3(self):
7843-
# def f(x):
7844-
# return [x['hi'] * x['bye']]
7845-
7846-
# y, f_vjp = api.vjp(f, {'hi': 2., 'bye': 3.})
7847-
# arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
7848-
# self.assertAllClose(y, [6.])
7849-
# self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
7850-
7851-
# f_vjp.args_res[0] = {'hi': 2.}
7852-
# with self.assertRaisesRegex(ValueError, "but the structures differ"):
7853-
# f_vjp(1.)
7854-
7855-
def test_fsdp(self):
7856-
# see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
7857-
def f2(x, w):
7858-
x = 1. * x
7859-
x = x @ w
7860-
x = 2. * x
7861-
return x
7862-
7863-
x = jnp.ones((3, 4))
7864-
w = jnp.ones((4, 4))
7865-
y, f2_sivjp = api.si_vjp(f2, [False, True], x, w)
7866-
y_grad = jnp.ones_like(y)
7867-
x_grad, w_grad = f2_sivjp(y_grad, w)
7868-
self.assertAllClose(x_grad, 2. * y_grad @ w.T)
7829+
# TODO(mattjj): Raise an error message.
7830+
# with self.assertRaisesRegex(ValueError, "but the structures differ"):
7831+
# f_vjp.args_res = [{'hi': 2.}]
7832+
# f_vjp([1.])
78697833

78707834
def test_fsdp_error(self):
78717835
# see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
@@ -7877,10 +7841,12 @@ def f2(x, w):
78777841

78787842
x = jnp.ones((3, 4))
78797843
w = jnp.ones((4, 4))
7880-
y, f2_sivjp = api.si_vjp(f2, [False, True], x, w)
7844+
y, f2_vjp = jax.vjp(f2, x, w)
7845+
f2_vjp.args_res[1] = None
78817846
y_grad = jnp.ones((2, 4))
7847+
f2_vjp.args_res[1] = w
78827848
with self.assertRaisesRegex(ValueError, "unexpected JAX type"):
7883-
f2_sivjp(y_grad, w)
7849+
f2_vjp(y_grad)
78847850

78857851
def test_fsdp_vjp3(self):
78867852
# see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
@@ -7902,10 +7868,11 @@ def f2(x, w):
79027868
self.assertAllClose(w_grad, 2. * x.T @ y_grad)
79037869

79047870
def test_doesnt_leak_symbolic_zeros(self):
7905-
_, vjp = api.si_vjp(lambda x: 1., [False], 3.14)
7871+
_, vjp = jax.vjp(lambda x: 1., 3.14)
79067872
ans, = vjp(1.0)
79077873
self.assertIsInstance(ans, jax.Array)
79087874

7875+
79097876
class TracebackTest(jtu.JaxTestCase):
79107877
# These tests are to catch regressions in Python traceback sizes. Our
79117878
# second-order APIs can be nested arbitrarily and if each one adds a dozen

0 commit comments

Comments
 (0)