@@ -172,7 +172,6 @@ def jit(
172172 device : xc .Device | None = ...,
173173 backend : str | None = ...,
174174 inline : bool = ...,
175- abstracted_axes : Any | None = ...,
176175 compiler_options : dict [str , Any ] | None = ...,
177176) -> pjit .JitWrapped : ...
178177
@@ -189,7 +188,6 @@ def jit(
189188 device : xc .Device | None = ...,
190189 backend : str | None = ...,
191190 inline : bool = ...,
192- abstracted_axes : Any | None = ...,
193191 compiler_options : dict [str , Any ] | None = ...,
194192) -> Callable [[Callable ], pjit .JitWrapped ]: ...
195193
@@ -205,7 +203,6 @@ def jit(
205203 device : xc .Device | None = None ,
206204 backend : str | None = None ,
207205 inline : bool = False ,
208- abstracted_axes : Any | None = None ,
209206 compiler_options : dict [str , Any ] | None = None ,
210207) -> pjit .JitWrapped | Callable [[Callable ], pjit .JitWrapped ]:
211208 """Sets up ``fun`` for just-in-time compilation with XLA.
@@ -350,8 +347,7 @@ def jit(
350347 static_argnums = static_argnums , static_argnames = static_argnames ,
351348 donate_argnums = donate_argnums , donate_argnames = donate_argnames ,
352349 keep_unused = keep_unused , device = device , backend = backend , inline = inline ,
353- abstracted_axes = abstracted_axes , compiler_options = compiler_options ,
354- use_resource_env = False )
350+ compiler_options = compiler_options , use_resource_env = False )
355351 if isinstance (fun , NotSpecified ):
356352 return lambda fun : pjit .make_jit (fun , ** kwds )
357353 else :
@@ -2563,13 +2559,13 @@ def transposed_fun(const, out_cotangent):
25632559 return Partial (transposed_fun , const )
25642560
25652561
2566- def _flat_axes_specs (abstracted_axes , * args , ** kwargs
2562+ def _flat_axes_specs (* args , ** kwargs
25672563 ) -> list [pe .AbstractedAxesSpec ]:
25682564 if kwargs : raise NotImplementedError
25692565 def ax_leaf (l ):
25702566 return (isinstance (l , dict ) and all_leaves (l .values ()) or
25712567 isinstance (l , tuple ) and all_leaves (l , lambda x : x is None ))
2572- return broadcast_prefix (abstracted_axes , args , ax_leaf )
2568+ return broadcast_prefix (args , ax_leaf )
25732569
25742570
25752571@overload
@@ -2578,7 +2574,6 @@ def make_jaxpr(
25782574 static_argnums : int | Iterable [int ] = (),
25792575 axis_env : Sequence [tuple [AxisName , int ]] | None = None ,
25802576 return_shape : Literal [False ] = ...,
2581- abstracted_axes : Any | None = None ,
25822577) -> Callable [..., core .ClosedJaxpr ]:
25832578 ...
25842579
@@ -2588,7 +2583,6 @@ def make_jaxpr(
25882583 static_argnums : int | Iterable [int ] = (),
25892584 axis_env : Sequence [tuple [AxisName , int ]] | None = None ,
25902585 return_shape : Literal [True ] = ...,
2591- abstracted_axes : Any | None = None ,
25922586) -> Callable [..., tuple [core .ClosedJaxpr , Any ]]:
25932587 ...
25942588
@@ -2598,7 +2592,6 @@ def make_jaxpr(
25982592 static_argnums : int | Iterable [int ] = (),
25992593 axis_env : Sequence [tuple [AxisName , int ]] | None = None ,
26002594 return_shape : bool = False ,
2601- abstracted_axes : Any | None = None ,
26022595) -> Callable [..., core .ClosedJaxpr | tuple [core .ClosedJaxpr , Any ]]:
26032596 """Create a function that returns the jaxpr of ``fun`` given example args.
26042597
@@ -2666,8 +2659,7 @@ def make_jaxpr(
26662659 @api_boundary
26672660 def make_jaxpr_f (* args , ** kwargs ):
26682661 with core .extend_axis_env_nd (axis_env or []):
2669- traced = jit (fun , static_argnums = static_argnums ,
2670- abstracted_axes = abstracted_axes ).trace (* args , ** kwargs )
2662+ traced = jit (fun , static_argnums = static_argnums ).trace (* args , ** kwargs )
26712663 # `jit` converts tracers in consts to args but `make_jaxpr` callers expect
26722664 # consts not to be converted.
26732665 num_consts = traced ._num_consts
0 commit comments