@@ -2212,106 +2212,12 @@ def vjp(
22122212 fun , debug_info = debug_info ("vjp" , fun , primals , {}))
22132213 return _vjp (wrapped_fun , * primals , has_aux = has_aux )
22142214
2215- @partial (api_boundary , repro_api_name = "jax.experimental.saved_input_vjp" )
2216- def saved_input_vjp (f : Callable , which : Sequence [bool ], * primals ,
2217- allow_unused : bool = True , allow_opaque : bool = True ):
2218- if len (which ) != len (primals ):
2219- raise ValueError (
2220- "length of 'which' argument must equal the number of primal input values, "
2221- f"but got { len (which )= } and { len (primals )= } " )
2222-
2223- dbg = debug_info ("saved_input_vjp" , f , primals , {})
2224- fun = lu .wrap_init (f , debug_info = dbg )
2225- primals_flat , in_tree = tree_flatten (primals )
2226- fun , out_tree = flatten_fun_nokwargs (fun , in_tree )
2227- out_primals_flat , out_pvals , jaxpr , residuals = ad .linearize (fun , * primals_flat )
2228- out_known = [pval .is_known () for pval in out_pvals ]
2229- primals_filt , filt_tree = tree_flatten (tuple (p for w , p in zip (which , primals ) if w ))
2230- id_map = {id (x ): i for i , x in enumerate (primals_filt )}
2231- opaque_residuals = []
2232- res_spec = [RSpec (id_map [id (r )], True ) if id (r ) in id_map else
2233- RSpec (opaque_residuals .append (r ) or (len (opaque_residuals ) - 1 ), False ) # type: ignore
2234- for r in residuals ]
2235- out_primal_avals = map (shaped_abstractify , out_primals_flat )
2236- f_vjp = Partial (partial (_saved_input_vjpfun , res_spec , filt_tree , in_tree ,
2237- out_tree (), out_known , jaxpr , out_primal_avals ),
2238- opaque_residuals )
2239-
2240- if not allow_unused and not set (id_map ).issubset (res_ids := {id (r ) for r in residuals }):
2241- unused = [(i , core .get_aval (x )) for i , (x , w ) in enumerate (zip (primals , which ))
2242- if w and id (x ) not in res_ids ]
2243- assert unused
2244- if len (unused ) == 1 :
2245- (i , a ), = unused
2246- start , was = "an input value" , "was"
2247- msg = f" { dbg .arg_names [i ] if dbg .arg_names is not None else 'unknown' } of type { a .str_short ()} "
2248- else :
2249- start , was = "multiple input values" , "were"
2250- msg = "\n " + "\n " .join (f" * { dbg .arg_names [i ] if dbg .arg_names is not None else 'unknown' } of type { a .str_short ()} "
2251- for i , a in unused )
2252- raise Exception (f"with { allow_unused = } , { start } marked to be saved { was } "
2253- f"not used by the backward pass:{ msg } " )
2254-
2255- if not allow_opaque and opaque_residuals :
2256- msg = ", " .join (core .get_aval (x ).str_short () for x in opaque_residuals )
2257- raise Exception (f"with { allow_opaque = } , the backward pass requires opaque "
2258- f"(non-input) residuals: { msg } " )
2259-
2260- out_primals = tree_unflatten (out_tree (), out_primals_flat )
2261- return out_primals , f_vjp
2262-
2263- def _saved_input_vjpfun (res_spec , filtered_tree , in_tree , out_tree , out_known ,
2264- jaxpr , out_primal_avals , opaque_residuals , ct ,
2265- * saved_primals ):
2266- primals_filtered , filtered_tree_ = tree_flatten (saved_primals )
2267- if filtered_tree != filtered_tree_ :
2268- raise ValueError (
2269- "inputs passed to f_vjp must be a tuple of (pytrees of) "
2270- "arrays with the same structure as\n "
2271- " tuple(x for x, w in zip(inputs, which) if w)\n "
2272- "given the original call\n "
2273- " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n "
2274- "but the structures differ:\n " +
2275- "\n " .join (f" * inputs{ keystr (path )} was a { thing1 } in the original "
2276- f"call, but a { thing2 } here, so { explanation } "
2277- for path , thing1 , thing2 , explanation
2278- in equality_errors_pytreedef (filtered_tree , filtered_tree_ )))
2279-
2280- residuals = [primals_filtered [i .idx ] if i .primal else opaque_residuals [i .idx ]
2281- for i in res_spec ]
2282- dummy_args = [ad .UndefinedPrimal (v .aval ) for v in jaxpr .invars ]
2283- cts_flat , out_tree_ = tree_flatten (ct )
2284- if out_tree_ != out_tree :
2285- raise ValueError (f"unexpected tree structure of argument to vjp function: "
2286- f"got { out_tree_ } , but expected to match { out_tree } " )
2287- for arg , aval in zip (cts_flat , out_primal_avals ):
2288- ct_aval = shaped_abstractify (arg )
2289- ct_aval_expected = aval .to_cotangent_aval ()
2290- if (not core .typecompat (ct_aval , ct_aval_expected ) and
2291- not _temporary_dtype_exception (ct_aval , ct_aval_expected )):
2292- raise ValueError (
2293- "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: "
2294- f"got { ct_aval .str_short ()} , but expected { ct_aval_expected .str_short ()} "
2295- f"because the corresponding output of the function had JAX type "
2296- f"{ aval .str_short ()} " )
2297-
2298- cts_flat = [ct for ct , k in zip (cts_flat , out_known ) if not k ]
2299- arg_cts = ad .backward_pass (jaxpr , True , residuals , dummy_args , cts_flat )
2300- return tree_unflatten (in_tree , map (ad .instantiate_zeros , arg_cts ))
2301-
2302- @dataclasses .dataclass (frozen = True )
2303- class RSpec :
2304- idx : int
2305- primal : bool
2306-
2307- si_vjp = saved_input_vjp
2308-
2309-
23102215def _vjp (fun , * primals , has_aux = False ):
23112216 canon = lambda x : x if isinstance (x , core .Tracer ) else canonicalize_value (x )
23122217 primals = tree_map (canon , primals )
23132218 primals_flat , in_tree = tree_flatten (primals )
2314- for arg in primals_flat : dispatch .check_arg (arg )
2219+ for arg in primals_flat :
2220+ dispatch .check_arg (arg )
23152221 if not has_aux :
23162222 flat_fun , out_tree = flatten_fun_nokwargs (fun , in_tree )
23172223 out_primals_flat , out_pvals , jaxpr , residuals = ad .linearize (
@@ -2340,22 +2246,14 @@ def _vjp(fun, *primals, has_aux=False):
23402246 else :
23412247 return out_primals , f_vjp , tree_unflatten (aux_tree , aux )
23422248
2343- def tuptree_map (f , treedef , x ):
2344- return treedef .walk (lambda xs , _ : tuple (xs ), f , x )
2345-
2346-
2347- def _is_ref (x ):
2348- from jax ._src .state .types import AbstractRef
2349- try : return isinstance (typeof (x ), AbstractRef )
2350- except : return False
2351-
23522249def _vjp3_callable (spec , out_known , jaxpr , out_primal_avals , in_tree , out_tree ,
23532250 args_res , opaque_res , * maybe_ct_refs ):
23542251 if not maybe_ct_refs :
23552252 maybe_ct_refs_flat = [GradValue ()] * in_tree .num_leaves
23562253 else :
23572254 maybe_ct_refs_flat , in_tree_ = tree_flatten (maybe_ct_refs )
2358- if in_tree != in_tree_ : raise Exception # TODO accept isomorph tuple tree
2255+ if in_tree != in_tree_ :
2256+ raise Exception # TODO accept isomorph tuple tree
23592257 args_res_ = tree_leaves (args_res , is_leaf = lambda x : isinstance (x , NotNeeded ))
23602258 residuals = [args_res_ [i .idx ] if i .primal else opaque_res [i .idx ] for i in spec ]
23612259 maybe_refs = [ad .RefAccum (v .aval , x ) if _is_ref (x ) else ad .ValAccum (v .aval )
@@ -2366,7 +2264,8 @@ def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
23662264def _vjp3_bwd (in_tree , out_tree , out_known , jaxpr , out_primal_avals , residuals ,
23672265 maybe_refs , out_ct ):
23682266 cts_flat , out_tree_ = tree_flatten (out_ct )
2369- if out_tree != out_tree_ : _vjp_ct_tree_error (jaxpr , out_tree , out_tree_ )
2267+ if out_tree != out_tree_ :
2268+ _vjp_ct_tree_error (jaxpr , out_tree , out_tree_ )
23702269 _vjp_check_ct_avals (cts_flat , out_primal_avals )
23712270 cts_flat = [ct for ct , k in zip (cts_flat , out_known ) if not k ]
23722271 ad .backward_pass3 (jaxpr , True , residuals , maybe_refs , cts_flat )
@@ -2375,6 +2274,23 @@ def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
23752274 arg_cts = map (ad .instantiate_zeros , arg_cts )
23762275 return tree_unflatten (in_tree , arg_cts )
23772276
2277+
@dataclasses.dataclass(frozen=True)
class RSpec:
  """Immutable locator describing where one residual value is stored.

  Code in this file resolves a spec `s` as `saved_primal_leaves[s.idx]` when
  `s.primal` is true and as `opaque_residuals[s.idx]` otherwise, so the pair
  is effectively a tagged index into one of two sequences.
  """
  # Position of the residual inside the sequence selected by `primal`.
  idx: int
  # True: `idx` refers to a (flattened) primal input that doubles as a
  # residual.  False: `idx` refers to the list of opaque, non-input residuals.
  primal: bool
2282+
def tuptree_map(f, treedef, x):
  """Map `f` over the leaves `x` of `treedef`, rebuilding nodes as tuples.

  The work is delegated to `treedef.walk`, which is given a node handler, the
  leaf function `f`, and the leaves `x`.  The node handler discards its second
  argument and returns its first argument converted to a tuple, so every
  internal node of the result is a plain tuple regardless of the container
  type recorded in `treedef`.

  NOTE(review): this assumes `walk` invokes the node handler with two
  positional arguments, the first being the iterable of already-processed
  children -- confirm against the PyTreeDef API in use.
  """
  def _children_as_tuple(children, _ignored):
    return tuple(children)

  return treedef.walk(_children_as_tuple, f, x)
2285+
2286+ def _is_ref (x ):
2287+ from jax ._src .state .types import AbstractRef
2288+ try :
2289+ return isinstance (typeof (x ), AbstractRef )
2290+ except :
2291+ return False
2292+
2293+
23782294_vjp_too_many_args = """
23792295The function returned by `jax.vjp` applied to {} was called with {} arguments,
23802296but functions returned by `jax.vjp` must be called with a single argument
@@ -2396,6 +2312,7 @@ def f(x):
23962312arguments rather than in a tuple, this error can arise.
23972313""" .format
23982314
2315+
23992316def _vjp_ct_tree_error (jaxpr , out_tree , ct_tree ):
24002317 msg = f"""unexpected tree structure.
24012318
@@ -2410,6 +2327,7 @@ def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
24102327 in equality_errors_pytreedef (out_tree , ct_tree ))
24112328 raise ValueError (msg )
24122329
2330+
24132331def _vjp_check_ct_avals (cts , primal_avals ):
24142332 # TODO(mattjj): improve this error by flattening with keys in the first place
24152333 for ct , aval in zip (cts , primal_avals ):
@@ -2425,6 +2343,7 @@ def _vjp_check_ct_avals(cts, primal_avals):
24252343 "because the corresponding output of the differentiated function had JAX type "
24262344 f"{ aval .str_short ()} " )
24272345
2346+
24282347@register_dataclass
24292348@dataclasses .dataclass (frozen = True )
24302349class NotNeeded :
0 commit comments