
Commit 43f51d7

Michael Levesque-Dion authored and jax authors committed
Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
1 parent 8b95853 commit 43f51d7

10 files changed: 42 additions & 68 deletions


jax/_src/interpreters/mlir.py

Lines changed: 9 additions & 21 deletions
@@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
   return type_cast(ir.DenseIntElementsAttr,
                    ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))

-def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
-  if hlo.get_api_version() < 5:
-    return dense_int_elements(xs)
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore
-
-# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
-def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  if hlo.get_api_version() < 6:
-    return dense_int_elements(xs)
+def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
   return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore

 def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
@@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
   return ir.DenseElementsAttr.get(
       a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])

-def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
-  if hlo.get_api_version() < 6:
-    return dense_bool_elements(xs)
+def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
   return ir.DenseBoolArrayAttr.get(xs)  # type: ignore

 def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
@@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
         ir.RankedTensorType.get(
             val.shape, dtype_to_ir_type(collapsed_val.dtype)),  # type: ignore
         _numpy_array_constant(collapsed_val)[0],
-        dense_int_array_v6(other_axes))
+        dense_int_array(other_axes))
     return (out,)
   else:
     return _numpy_array_constant(val)
@@ -1885,14 +1873,14 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
     return hlo.dynamic_broadcast_in_dim(
         aval_to_ir_type(aval_out), op,
         shape,
-        dense_int_array_v6(broadcast_dimensions),
+        dense_int_array(broadcast_dimensions),
     )
   else:
     assert all(d != ir.ShapedType.get_dynamic_size()
                for d in aval_out.shape), aval_out  # type: ignore
     return hlo.broadcast_in_dim(
         aval_to_ir_type(aval_out), op,
-        dense_int_array_v6(broadcast_dimensions))
+        dense_int_array(broadcast_dimensions))

 def multi_broadcast_in_dim(ctx: LoweringRuleContext,
                            ops: Sequence[ir.Value],
@@ -2725,10 +2713,10 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
   rw = hlo.ReduceWindowOp(
       list(map(aval_to_ir_type, out_avals)),
       operands, init_values,
-      dense_int_array_v6(window_dimensions),
-      window_strides=dense_int_array_v6(window_strides),
-      base_dilations=dense_int_array_v6(base_dilation),
-      window_dilations=dense_int_array_v6(window_dilation),
+      dense_int_array(window_dimensions),
+      window_strides=dense_int_array(window_strides),
+      base_dilations=dense_int_array(base_dilation),
+      window_dilations=dense_int_array(window_dilation),
       padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                           shape=[len(padding), 2]))
   reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
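
The dense_int_array and dense_bool_array helpers above now build the new array attributes unconditionally. A minimal sketch of the resulting behaviour, assuming jaxlib's bundled MLIR Python bindings are importable (the ir.Context scaffolding and printed forms are illustrative, not part of this change):

import numpy as np
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir

with ir.Context():
  # No hlo.get_api_version() branch anymore: the helpers return the
  # array attribute forms directly.
  int_attr = mlir.dense_int_array([0, 2, 3])        # ir.DenseI64ArrayAttr
  bool_attr = mlir.dense_bool_array([True, False])  # ir.DenseBoolArrayAttr
  print(int_attr)   # e.g. array<i64: 0, 2, 3>
  print(bool_attr)  # e.g. array<i1: true, false>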

jax/_src/lax/convolution.py

Lines changed: 6 additions & 6 deletions
@@ -719,10 +719,10 @@ def _conv_general_dilated_lower(
         dimension_numbers=dnums,
         feature_group_count=mlir.i64_attr(feature_group_count),
         batch_group_count=mlir.i64_attr(batch_group_count),
-        window_strides=mlir.dense_int_array_v6(window_strides),
+        window_strides=mlir.dense_int_array(window_strides),
         padding=mlir.dense_int_elements(padding),
-        lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
-        rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
         window_reversal=window_reversal,
         precision_config=lax.precision_attr(precision))
   ]
@@ -744,9 +744,9 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
         dimension_numbers=dnums,
         feature_group_count=mlir.i64_attr(feature_group_count),
         batch_group_count=mlir.i64_attr(batch_group_count),
-        window_strides=mlir.dense_int_array_v6(window_strides),
-        lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
-        rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
+        window_strides=mlir.dense_int_array(window_strides),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
         window_reversal=window_reversal,
         precision_config=lax.precision_attr(precision))
   ]

jax/_src/lax/lax.py

Lines changed: 3 additions & 3 deletions
@@ -1760,7 +1760,7 @@ def broadcast_hlo(
   for aval, arg in zip(avals, args):
     if aval.shape != aval_out.shape:
       assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
-      dims = mlir.dense_int_array_v6(
+      dims = mlir.dense_int_array(
           range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
       if any(isinstance(d, ir.Value) for d in aval_out.shape):
         arg = hlo.dynamic_broadcast_in_dim(
@@ -3963,7 +3963,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
   operands, init_values = util.split_list(values, [len(values) // 2])
   init_value_avals = ctx.avals_in[len(values) // 2:]
   op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                    operands, init_values, mlir.dense_int_array_v6(dimensions))
+                    operands, init_values, mlir.dense_int_array(dimensions))
   ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
   with ir.InsertionPoint(reducer):
@@ -4174,7 +4174,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
   dtype = aval_out.dtype
   op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
                     mlir.ir_constants(unit_factory(aval_out.dtype)),
-                    mlir.dense_int_array_v6(axes))
+                    mlir.dense_int_array(axes))
   scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
   reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(reducer_region):

jax/_src/lax/parallel.py

Lines changed: 1 addition & 1 deletion
@@ -1271,7 +1271,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
     broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
     x = hlo.broadcast_in_dim(
         mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
-        mlir.dense_int_array_v6(broadcast_dimensions))
+        mlir.dense_int_array(broadcast_dimensions))
   replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                    axis_index_groups)
   if is_spmd:

jax/_src/lax/slicing.py

Lines changed: 1 addition & 1 deletion
@@ -1845,7 +1845,7 @@ def _gather_lower(ctx, operand, indices, *,
       operand,
       indices,
       dnums,
-      mlir.dense_int_array_v6(slice_sizes),
+      mlir.dense_int_array(slice_sizes),
       indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]

 mlir.register_lowering(gather_p, _gather_lower)

jax/_src/lax/windowed_reductions.py

Lines changed: 2 additions & 2 deletions
@@ -665,8 +665,8 @@ def _select_and_scatter_lower(
       operand,
       source,
       init_value,
-      window_dimensions=mlir.dense_int_array_v6(window_dimensions),
-      window_strides=mlir.dense_int_array_v6(window_strides),
+      window_dimensions=mlir.dense_int_array(window_dimensions),
+      window_strides=mlir.dense_int_array(window_strides),
       padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                           shape=(len(padding), 2)))
   select = op.select.blocks.append(scalar_type, scalar_type)

jax/experimental/export/_export.py

Lines changed: 17 additions & 20 deletions
@@ -514,26 +514,23 @@ def export_sharding(s: LoweringSharding,

 def _module_to_bytecode(module: ir.Module) -> bytes:
   mlir_str = mlir.module_to_bytecode(module)
-  if hlo.get_api_version() < 4:
-    target_version = hlo.get_earliest_forward_compatible_version()
-  else:
-    # `target_version` is used to manage situations when a StableHLO producer
-    # (in this case, jax2tf) and a StableHLO consumer were built using
-    # different versions of StableHLO.
-    #
-    # Each StableHLO version `producer_version` has a compatibility window,
-    # i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
-    # where StableHLO portable artifacts serialized by `producer_version`
-    # can be deserialized by `consumer_version` within the window.
-    # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
-    # for the exact extent of these compatibility guarantees.
-    #
-    # `hlo.get_minimum_version()` returns `consumer_version_min`
-    # for the current version of StableHLO. We are using it here to maximize
-    # forward compatibility, i.e. to maximize how far into the past we can go
-    # and still have the payloads produced by `serialize_portable_artifact`
-    # compatible with potential consumers from the past.
-    target_version = hlo.get_minimum_version()
+  # `target_version` is used to manage situations when a StableHLO producer
+  # (in this case, jax2tf) and a StableHLO consumer were built using
+  # different versions of StableHLO.
+  #
+  # Each StableHLO version `producer_version` has a compatibility window,
+  # i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
+  # where StableHLO portable artifacts serialized by `producer_version`
+  # can be deserialized by `consumer_version` within the window.
+  # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
+  # for the exact extent of these compatibility guarantees.
+  #
+  # `hlo.get_minimum_version()` returns `consumer_version_min`
+  # for the current version of StableHLO. We are using it here to maximize
+  # forward compatibility, i.e. to maximize how far into the past we can go
+  # and still have the payloads produced by `serialize_portable_artifact`
+  # compatible with potential consumers from the past.
+  target_version = hlo.get_minimum_version()
   module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
       mlir_str, target_version)
   return module_serialized
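
With the pre-v4 branch gone, the export path always targets hlo.get_minimum_version(). A standalone sketch of the same flow, mirroring the code above (module_to_portable_bytes is a hypothetical wrapper name; the import paths assume JAX's internal layout):

from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo

def module_to_portable_bytes(module) -> bytes:
  # Serialize the StableHLO module, targeting the oldest version the current
  # StableHLO libraries still promise to understand (consumer_version_min),
  # which maximizes forward compatibility of the resulting artifact.
  mlir_str = mlir.module_to_bytecode(module)
  target_version = hlo.get_minimum_version()
  return xla_client._xla.mlir.serialize_portable_artifact(
      mlir_str, target_version)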

jax/interpreters/mlir.py

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
   dense_bool_elements as dense_bool_elements,
   dense_bool_array as dense_bool_array,
   dense_int_array as dense_int_array,
-  dense_int_array_v6 as dense_int_array_v6,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
   emit_python_callback as emit_python_callback,

jaxlib/gpu_solver.py

Lines changed: 2 additions & 3 deletions
@@ -28,7 +28,7 @@

 from .hlo_helpers import (
     DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
-    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array, dense_int_array_v6)
+    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)

 try:
   from .cuda import _blas as _cublas  # pytype: disable=import-error
@@ -536,14 +536,13 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
   # simply copy it back to where it needs to be:
   intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
   intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
-  intarrattr_v6 = lambda xs: dense_int_array_v6(np.asarray(xs, np.int64))
   if not lower and platform == "cu" and m > 1:
     start = (0,) * len(batch_dims) + (0,)
     end = batch_dims + (1,)
     s = hlo.slice(
         e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
     s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
-    s = hlo.broadcast_in_dim(s_type, s, intarrattr_v6(range(len(dims) - 1)))
+    s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
     # The diagonals are always real; convert to complex if needed.
     s = hlo.convert(
         ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)

jaxlib/hlo_helpers.py

Lines changed: 1 addition & 10 deletions
@@ -110,16 +110,7 @@ def hlo_s32(x: int):
 def ensure_hlo_s32(x: DimensionSize):
   return hlo_s32(x) if isinstance(x, int) else x

-def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
-  if hlo.get_api_version() < 5:
-    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
-
-# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
-def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  if hlo.get_api_version() < 6:
-    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
   return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))

 def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
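
For reference, the two attribute encodings involved in this migration print differently; a small illustration, assuming the MLIR Python bindings shipped with jaxlib:

import numpy as np
from jaxlib.mlir import ir

with ir.Context():
  xs = np.asarray([1, 2, 3], np.int64)
  # Old fallback: an elements attribute that carries a ranked tensor type.
  elements = ir.DenseIntElementsAttr.get(xs)  # dense<[1, 2, 3]> : tensor<3xi64>
  # What dense_int_array returns now: a flat i64 array attribute.
  array = ir.DenseI64ArrayAttr.get(xs)        # array<i64: 1, 2, 3>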
