Skip to content

Commit 249b7a7

Browse files
pschuh authored and Google-ML-Automation committed
Add ResultHandler.wrap which allows us to avoid exploding PRNG keys.
PiperOrigin-RevId: 842010889
1 parent c74b2a5 commit 249b7a7

File tree

10 files changed

+109
-44
lines changed

10 files changed

+109
-44
lines changed

jax/_src/array.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax._src.interpreters import pxla
4040
from jax._src.layout import AutoLayout, Format, Layout
4141
from jax._src.lib import _jax
42+
from jax._src.lib import jaxlib_extension_version
4243
from jax._src.lib import xla_client as xc
4344
from jax._src.mesh import empty_concrete_mesh
4445
from jax._src.sharding import Sharding
@@ -1285,7 +1286,14 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
12851286

12861287
def _array_global_result_handler(global_aval, out_sharding, committed):
12871288
if global_aval.dtype == dtypes.float0:
1288-
return lambda _: np.zeros(global_aval.shape, dtypes.float0)
1289+
def handler(xs):
1290+
return np.zeros(global_aval.shape, dtypes.float0)
1291+
if jaxlib_extension_version >= 390:
1292+
phys_aval = core.physical_aval(global_aval)
1293+
return xc.array_result_handler(phys_aval, out_sharding, committed=committed,
1294+
_skip_checks=True).wrap(handler)
1295+
else:
1296+
return handler
12891297
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
12901298
return global_aval.dtype._rules.global_sharded_result_handler(
12911299
global_aval, out_sharding, committed)
@@ -1297,7 +1305,14 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
12971305
# Only used for Arrays that come out of pmap.
12981306
def _array_local_result_handler(aval, sharding, indices):
12991307
if aval.dtype == dtypes.float0:
1300-
return lambda _: np.zeros(aval.shape, dtypes.float0)
1308+
def handler(xs):
1309+
return np.zeros(aval.shape, dtypes.float0)
1310+
if jaxlib_extension_version >= 390:
1311+
phys_aval = core.physical_aval(aval)
1312+
return xc.array_result_handler(phys_aval, sharding, committed=True,
1313+
_skip_checks=True).wrap(handler)
1314+
else:
1315+
return handler
13011316
if dtypes.issubdtype(aval.dtype, dtypes.extended):
13021317
return aval.dtype._rules.local_sharded_result_handler(
13031318
aval, sharding, indices)
@@ -1326,9 +1341,13 @@ def _token_shard_arg(xs, shardings, layouts, copy_semantics):
13261341
def _token_global_result_handler(global_aval, out_sharding, committed):
13271342
array_handler = _array_global_result_handler(
13281343
core.get_token_aval(), out_sharding, committed)
1329-
1330-
def wrapper(*args, **kwargs):
1331-
out_buf = array_handler(*args, **kwargs)
1332-
return core.Token(out_buf)
1333-
return wrapper
1344+
if jaxlib_extension_version >= 390:
1345+
def wrapper(array):
1346+
return core.Token(array)
1347+
return array_handler.wrap(wrapper) # type: ignore
1348+
else:
1349+
def old_wrapper(*args, **kwargs):
1350+
out_buf = array_handler(*args, **kwargs)
1351+
return core.Token(out_buf)
1352+
return old_wrapper
13341353
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def local_aval_to_result_handler(
302302
raise TypeError(
303303
f"No pxla_result_handler for type: {type(aval)}") from err
304304

305-
PxlaResultHandler = Callable[..., Callable[[Any], Any]]
305+
PxlaResultHandler = Callable[..., xc._xla.ResultHandler]
306306
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
307307

308308

jax/_src/lax/lax.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from jax._src.lax.utils import (
6464
input_dtype, dtype_to_string, standard_multi_result_abstract_eval,
6565
standard_primitive)
66+
from jax._src.lib import jaxlib_extension_version
6667
from jax._src.lib.mlir import ir
6768
from jax._src.lib.mlir.dialects import chlo
6869
from jax._src.lib.mlir.dialects import hlo
@@ -9246,10 +9247,14 @@ def global_sharded_result_handler(aval, out_sharding, committed):
92469247
else:
92479248
phys_sharding = out_sharding
92489249
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
9249-
9250-
def handler(bufs):
9251-
return core.DArray(aval, phys_handler(bufs))
9252-
return handler
9250+
if jaxlib_extension_version >= 390:
9251+
def handler(arr):
9252+
return core.DArray(aval, arr)
9253+
return phys_handler.wrap(handler)
9254+
else:
9255+
def handler(bufs):
9256+
return core.DArray(aval, phys_handler(bufs))
9257+
return handler
92539258

92549259

92559260
core.bint._rules = BIntRules

jax/_src/prng.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from jax._src.lax import control_flow as lax_control_flow
4343
from jax._src.lax import lax
4444
from jax._src.lax import slicing as lax_slicing
45+
from jax._src.lib import jaxlib_extension_version
4546
from jax._src.lib import gpu_prng
4647
from jax._src.lib import xla_client as xc
4748
from jax._src.lib.mlir import ir
@@ -402,10 +403,16 @@ def local_sharded_result_handler(aval, sharding, indices):
402403
phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)
403404

404405
# set up a handler that calls the physical one and wraps back up
405-
def handler(bufs):
406-
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
406+
if jaxlib_extension_version >= 390:
407+
def handler(arr):
408+
return PRNGKeyArray(aval.dtype._impl, arr)
407409

408-
return handler
410+
return phys_handler.wrap(handler)
411+
else:
412+
def handler(bufs):
413+
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
414+
415+
return handler
409416

410417
@staticmethod
411418
def global_sharded_result_handler(aval, out_sharding, committed):
@@ -414,9 +421,14 @@ def global_sharded_result_handler(aval, out_sharding, committed):
414421

415422
phys_sharding = physical_sharding(aval, out_sharding)
416423
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
417-
def handler(bufs):
418-
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
419-
return handler
424+
if jaxlib_extension_version >= 390:
425+
def handler(bufs):
426+
return PRNGKeyArray(aval.dtype._impl, bufs)
427+
return phys_handler.wrap(handler)
428+
else:
429+
def handler(bufs):
430+
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
431+
return handler
420432

421433
@staticmethod
422434
def make_sharded_array(aval, sharding, arrays, committed):

jaxlib/_jax/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,7 @@ def array_result_handler(
10291029

10301030
class ResultHandler:
10311031
def __call__(self, arg: Array | Sequence[Array], /) -> Array: ...
1032+
def wrap(self, wrapper: Callable) -> Any: ...
10321033

10331034
class DeviceList:
10341035
def __init__(self, arg: tuple[Device, ...], /) -> None: ...

jaxlib/py_array.cc

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -584,18 +584,21 @@ PyArray PyArray::MakeFromIfrtArrayAndSharding(nb_class_ptr<PyClient> py_client,
584584
std::move(ifrt_array), committed, skip_checks);
585585
}
586586

587-
PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding,
588-
bool committed, bool skip_checks)
587+
PyArrayResultHandler::PyArrayResultHandler(
588+
nb::object aval, nb::object sharding, bool committed, bool skip_checks,
589+
std::vector<nanobind::callable> wrappers)
589590
: aval_(std::move(aval)),
590591
sharding_(std::move(sharding)),
591592
committed_(committed),
592-
skip_checks_(skip_checks) {
593+
skip_checks_(skip_checks),
594+
wrappers_(std::move(wrappers)) {
593595
weak_type_ = nb::cast<bool>(aval_.attr("weak_type"));
594596
dtype_ = nb::cast<xla::nb_dtype>(aval_.attr("dtype"));
595597
shape_ = nb::cast<std::vector<int64_t>>(aval_.attr("shape"));
596598
}
597599

598-
PyArray PyArrayResultHandler::Call(absl::Span<const PyArray> py_arrays) const {
600+
nanobind::object PyArrayResultHandler::Call(
601+
absl::Span<const PyArray> py_arrays) const {
599602
auto py_device_list = GetPyDeviceList(sharding_);
600603
if (!py_device_list.ok()) {
601604
throw nb::value_error(
@@ -610,15 +613,20 @@ PyArray PyArrayResultHandler::Call(absl::Span<const PyArray> py_arrays) const {
610613
xla::Future<>());
611614
}
612615

613-
PyArray PyArrayResultHandler::Call(nb_class_ptr<PyClient> py_client,
614-
ifrt::ArrayRef ifrt_array,
615-
xla::Future<> result_status) const {
616-
return PyArray(aval_, weak_type_, dtype_, shape_, sharding_,
617-
std::move(py_client), std::move(ifrt_array), committed_,
618-
skip_checks_, std::move(result_status));
616+
nanobind::object PyArrayResultHandler::Call(nb_class_ptr<PyClient> py_client,
617+
ifrt::ArrayRef ifrt_array,
618+
xla::Future<> result_status) const {
619+
nanobind::object result =
620+
PyArray(aval_, weak_type_, dtype_, shape_, sharding_,
621+
std::move(py_client), std::move(ifrt_array), committed_,
622+
skip_checks_, std::move(result_status));
623+
for (auto& cb : wrappers_) {
624+
result = cb(std::move(result));
625+
}
626+
return result;
619627
}
620628

621-
PyArray PyArrayResultHandler::Call(PyArray py_array) const {
629+
nanobind::object PyArrayResultHandler::Call(PyArray py_array) const {
622630
return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()),
623631
xla::Future<>());
624632
}
@@ -2364,7 +2372,14 @@ absl::Status PyArray::Register(nb::module_& m) {
23642372
.c_str());
23652373
},
23662374
nb::sig(
2367-
"def __call__(self, arg: Array | Sequence[Array], /) -> Array"));
2375+
"def __call__(self, arg: Array | Sequence[Array], /) -> Array"))
2376+
.def("wrap", [](const PyArrayResultHandler& self, nb::callable wrapper) {
2377+
auto wrappers = self.wrappers();
2378+
wrappers.push_back(std::move(wrapper));
2379+
return make_nb_class<PyArrayResultHandler>(
2380+
self.aval(), self.sharding(), self.committed(), self.skip_checks(),
2381+
std::move(wrappers));
2382+
});
23682383

23692384
return absl::OkStatus();
23702385
}

jaxlib/py_array.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,22 @@ class PyArray : public nanobind::object {
359359
class PyArrayResultHandler {
360360
public:
361361
PyArrayResultHandler(nanobind::object aval, nanobind::object sharding,
362-
bool committed, bool skip_checks);
362+
bool committed, bool skip_checks,
363+
std::vector<nanobind::callable> wrappers = {});
363364

364-
PyArray Call(absl::Span<const PyArray> py_arrays) const;
365-
PyArray Call(PyArray py_array) const;
365+
nanobind::object Call(absl::Span<const PyArray> py_arrays) const;
366+
nanobind::object Call(PyArray py_array) const;
366367

367-
PyArray Call(nb_class_ptr<PyClient> py_client, xla::ifrt::ArrayRef ifrt_array,
368-
xla::Future<> result_status = xla::Future<>()) const;
368+
nanobind::object Call(nb_class_ptr<PyClient> py_client,
369+
xla::ifrt::ArrayRef ifrt_array,
370+
xla::Future<> result_status = xla::Future<>()) const;
371+
372+
const std::vector<nanobind::callable>& wrappers() const { return wrappers_; }
373+
374+
nanobind::object aval() const { return aval_; }
375+
nanobind::object sharding() const { return sharding_; }
376+
bool committed() const { return committed_; }
377+
bool skip_checks() const { return skip_checks_; }
369378

370379
private:
371380
nanobind::object aval_;
@@ -376,6 +385,7 @@ class PyArrayResultHandler {
376385

377386
xla::nb_dtype dtype_;
378387
std::vector<int64_t> shape_;
388+
std::vector<nanobind::callable> wrappers_;
379389
};
380390

381391
absl::StatusOr<nanobind::object> CudaArrayInterfaceToBuffer(

jaxlib/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# Please suffix the version number with a brief description of your change
4848
# in a comment. The goal here is to force a merge conflict if two changes
4949
# attempt to grab the same version number.
50-
_version = 389 # LoadedExecutable.serialize
50+
_version = 390 # ResultHandler.wrap
5151

5252
# An internal increasing version number for protecting jaxlib code against
5353
# ifrt changes.

tests/dtypes_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from jax._src import literals
3535
from jax._src import test_util as jtu
3636
from jax._src.lax import lax as lax_internal
37+
from jax._src.lib import jaxlib_extension_version
3738

3839
config.parse_flags_with_absl()
3940

@@ -872,6 +873,8 @@ def global_sharded_result_handler(aval, out_sharding, committed):
872873
phys_aval = core.physical_aval(aval)
873874
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
874875
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
876+
if jaxlib_extension_version >= 390:
877+
return phys_handler.wrap(lambda arr: earray.EArray(aval, arr))
875878
return lambda bufs: earray.EArray(aval, phys_handler(bufs))
876879

877880
@dataclasses.dataclass(frozen=True)

tests/lax_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from jax._src.internal_test_util import lax_test_util
4949
from jax._src.lax import lax as lax_internal
5050
from jax._src.lax import utils as lax_utils
51+
from jax._src.lib import jaxlib_extension_version
5152
from jax._src.util import safe_zip
5253
from jax._src.tree_util import tree_map
5354

@@ -3992,14 +3993,13 @@ def handler(_, buf):
39923993

39933994
@staticmethod
39943995
def global_sharded_result_handler(aval, out_sharding, committed):
3995-
def handler(arr):
3996-
from jax._src.array import ArrayImpl
3997-
if isinstance(arr, ArrayImpl):
3998-
buf, = arr._arrays
3999-
else:
4000-
buf, = arr
4001-
return FooArray(aval.shape, buf)
4002-
return handler
3996+
phys_sharding = out_sharding # unlike KeyTyRules, assume same shape
3997+
phys_aval = core.physical_aval(aval)
3998+
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
3999+
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
4000+
if jaxlib_extension_version >= 390:
4001+
return phys_handler.wrap(lambda arr: FooArray(aval.shape, arr))
4002+
return lambda bufs: FooArray(aval.shape, phys_handler(bufs))
40034003

40044004

40054005
class FooTy(dtypes.ExtendedDType):

0 commit comments

Comments
 (0)