Skip to content

Commit 75697ef

Browse files
dougalmmattjj
authored andcommitted
Remove dynamic shapes. Dead weight at this point.
1 parent eeab9e4 commit 75697ef

40 files changed

+261
-4981
lines changed

ci/run_bazel_test_tpu.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ else
188188
//tests/pallas:tpu_pallas_call_print_test_tpu \
189189
//tests/pallas:indexing_test_tpu \
190190
//tests/pallas:pallas_error_handling_test_tpu \
191-
//tests/pallas:pallas_jumble_test_tpu \
192191
//tests/pallas:pallas_shape_poly_test_tpu \
193192
//tests/pallas:tpu_all_gather_test_tpu \
194193
//tests/pallas:tpu_fusible_matmul_test_tpu \

jax/_src/api.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def jit(
172172
device: xc.Device | None = ...,
173173
backend: str | None = ...,
174174
inline: bool = ...,
175-
abstracted_axes: Any | None = ...,
176175
compiler_options: dict[str, Any] | None = ...,
177176
) -> pjit.JitWrapped: ...
178177

@@ -189,7 +188,6 @@ def jit(
189188
device: xc.Device | None = ...,
190189
backend: str | None = ...,
191190
inline: bool = ...,
192-
abstracted_axes: Any | None = ...,
193191
compiler_options: dict[str, Any] | None = ...,
194192
) -> Callable[[Callable], pjit.JitWrapped]: ...
195193

@@ -205,7 +203,6 @@ def jit(
205203
device: xc.Device | None = None,
206204
backend: str | None = None,
207205
inline: bool = False,
208-
abstracted_axes: Any | None = None,
209206
compiler_options: dict[str, Any] | None = None,
210207
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
211208
"""Sets up ``fun`` for just-in-time compilation with XLA.
@@ -350,8 +347,7 @@ def jit(
350347
static_argnums=static_argnums, static_argnames=static_argnames,
351348
donate_argnums=donate_argnums, donate_argnames=donate_argnames,
352349
keep_unused=keep_unused, device=device, backend=backend, inline=inline,
353-
abstracted_axes=abstracted_axes, compiler_options=compiler_options,
354-
use_resource_env=False)
350+
compiler_options=compiler_options, use_resource_env=False)
355351
if isinstance(fun, NotSpecified):
356352
return lambda fun: pjit.make_jit(fun, **kwds)
357353
else:
@@ -2563,13 +2559,13 @@ def transposed_fun(const, out_cotangent):
25632559
return Partial(transposed_fun, const)
25642560

25652561

2566-
def _flat_axes_specs(abstracted_axes, *args, **kwargs
2562+
def _flat_axes_specs(*args, **kwargs
25672563
) -> list[pe.AbstractedAxesSpec]:
25682564
if kwargs: raise NotImplementedError
25692565
def ax_leaf(l):
25702566
return (isinstance(l, dict) and all_leaves(l.values()) or
25712567
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
2572-
return broadcast_prefix(abstracted_axes, args, ax_leaf)
2568+
return broadcast_prefix(args, ax_leaf)
25732569

25742570

25752571
@overload
@@ -2578,7 +2574,6 @@ def make_jaxpr(
25782574
static_argnums: int | Iterable[int] = (),
25792575
axis_env: Sequence[tuple[AxisName, int]] | None = None,
25802576
return_shape: Literal[False] = ...,
2581-
abstracted_axes: Any | None = None,
25822577
) -> Callable[..., core.ClosedJaxpr]:
25832578
...
25842579

@@ -2588,7 +2583,6 @@ def make_jaxpr(
25882583
static_argnums: int | Iterable[int] = (),
25892584
axis_env: Sequence[tuple[AxisName, int]] | None = None,
25902585
return_shape: Literal[True] = ...,
2591-
abstracted_axes: Any | None = None,
25922586
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
25932587
...
25942588

@@ -2598,7 +2592,6 @@ def make_jaxpr(
25982592
static_argnums: int | Iterable[int] = (),
25992593
axis_env: Sequence[tuple[AxisName, int]] | None = None,
26002594
return_shape: bool = False,
2601-
abstracted_axes: Any | None = None,
26022595
) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]:
26032596
"""Create a function that returns the jaxpr of ``fun`` given example args.
26042597
@@ -2666,8 +2659,7 @@ def make_jaxpr(
26662659
@api_boundary
26672660
def make_jaxpr_f(*args, **kwargs):
26682661
with core.extend_axis_env_nd(axis_env or []):
2669-
traced = jit(fun, static_argnums=static_argnums,
2670-
abstracted_axes=abstracted_axes).trace(*args, **kwargs)
2662+
traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs)
26712663
# `jit` converts tracers in consts to args but `make_jaxpr` callers expect
26722664
# consts not to be converted.
26732665
num_consts = traced._num_consts

jax/_src/config.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,16 +1785,6 @@ def _validate_default_device(val):
17851785
default=False,
17861786
help=('Enables lowering BCOO ops to cuSparse.'))
17871787

1788-
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
1789-
# if the intended backend can handle lowering the result
1790-
dynamic_shapes = bool_state(
1791-
name='jax_dynamic_shapes',
1792-
default=False,
1793-
help=('Enables experimental features for staging out computations with '
1794-
'dynamic shapes.'),
1795-
include_in_jit_key=True,
1796-
include_in_trace_context=True)
1797-
17981788
# This is for stackless backward compat with e.g. equinox
17991789
eager_constant_folding = bool_state(
18001790
name='eager_constant_folding',

0 commit comments

Comments
 (0)