@@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
9090 return type_cast (ir .DenseIntElementsAttr ,
9191 ir .DenseIntElementsAttr .get (np .asarray (xs , np .int64 )))
9292
93- def dense_int_array (xs ) -> ir .DenseElementsAttr | ir .DenseI64ArrayAttr :
94- # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
95- if hlo .get_api_version () < 5 :
96- return dense_int_elements (xs )
97- return ir .DenseI64ArrayAttr .get (np .asarray (xs , np .int64 )) # type: ignore
98-
99- # TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
100- def dense_int_array_v6 (xs ) -> ir .DenseIntElementsAttr | ir .DenseI64ArrayAttr :
101- if hlo .get_api_version () < 6 :
102- return dense_int_elements (xs )
93+ def dense_int_array (xs ) -> ir .DenseI64ArrayAttr :
10394 return ir .DenseI64ArrayAttr .get (np .asarray (xs , np .int64 )) # type: ignore
10495
10596def dense_bool_elements (xs : Sequence [bool ]) -> ir .DenseElementsAttr :
@@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
111102 return ir .DenseElementsAttr .get (
112103 a , type = ir .IntegerType .get_signless (1 ), shape = [len (xs )])
113104
114- def dense_bool_array (xs : Sequence [bool ]) -> ir .DenseElementsAttr | ir .DenseBoolArrayAttr :
115- # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
116- if hlo .get_api_version () < 6 :
117- return dense_bool_elements (xs )
105+ def dense_bool_array (xs : Sequence [bool ]) -> ir .DenseBoolArrayAttr :
118106 return ir .DenseBoolArrayAttr .get (xs ) # type: ignore
119107
120108def i32_attr (i ): return ir .IntegerAttr .get (ir .IntegerType .get_signless (32 ), i )
@@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
321309 ir .RankedTensorType .get (
322310 val .shape , dtype_to_ir_type (collapsed_val .dtype )), # type: ignore
323311 _numpy_array_constant (collapsed_val )[0 ],
324- dense_int_array_v6 (other_axes ))
312+ dense_int_array (other_axes ))
325313 return (out ,)
326314 else :
327315 return _numpy_array_constant (val )
@@ -1885,14 +1873,14 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
18851873 return hlo .dynamic_broadcast_in_dim (
18861874 aval_to_ir_type (aval_out ), op ,
18871875 shape ,
1888- dense_int_array_v6 (broadcast_dimensions ),
1876+ dense_int_array (broadcast_dimensions ),
18891877 )
18901878 else :
18911879 assert all (d != ir .ShapedType .get_dynamic_size ()
18921880 for d in aval_out .shape ), aval_out # type: ignore
18931881 return hlo .broadcast_in_dim (
18941882 aval_to_ir_type (aval_out ), op ,
1895- dense_int_array_v6 (broadcast_dimensions ))
1883+ dense_int_array (broadcast_dimensions ))
18961884
18971885def multi_broadcast_in_dim (ctx : LoweringRuleContext ,
18981886 ops : Sequence [ir .Value ],
@@ -2725,10 +2713,10 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
27252713 rw = hlo .ReduceWindowOp (
27262714 list (map (aval_to_ir_type , out_avals )),
27272715 operands , init_values ,
2728- dense_int_array_v6 (window_dimensions ),
2729- window_strides = dense_int_array_v6 (window_strides ),
2730- base_dilations = dense_int_array_v6 (base_dilation ),
2731- window_dilations = dense_int_array_v6 (window_dilation ),
2716+ dense_int_array (window_dimensions ),
2717+ window_strides = dense_int_array (window_strides ),
2718+ base_dilations = dense_int_array (base_dilation ),
2719+ window_dilations = dense_int_array (window_dilation ),
27322720 padding = ir .DenseIntElementsAttr .get (np .asarray (padding , np .int64 ),
27332721 shape = [len (padding ), 2 ]))
27342722 reducer = rw .regions [0 ].blocks .append (* (scalar_types + scalar_types ))
0 commit comments