@@ -584,18 +584,21 @@ PyArray PyArray::MakeFromIfrtArrayAndSharding(nb_class_ptr<PyClient> py_client,
584584 std::move (ifrt_array), committed, skip_checks);
585585}
586586
587- PyArrayResultHandler::PyArrayResultHandler (nb::object aval, nb::object sharding,
588- bool committed, bool skip_checks)
587+ PyArrayResultHandler::PyArrayResultHandler (
588+ nb::object aval, nb::object sharding, bool committed, bool skip_checks,
589+ std::vector<nanobind::callable> wrappers)
589590 : aval_(std::move(aval)),
590591 sharding_(std::move(sharding)),
591592 committed_(committed),
592- skip_checks_(skip_checks) {
593+ skip_checks_(skip_checks),
594+ wrappers_(std::move(wrappers)) {
593595 weak_type_ = nb::cast<bool >(aval_.attr (" weak_type" ));
594596 dtype_ = nb::cast<xla::nb_dtype>(aval_.attr (" dtype" ));
595597 shape_ = nb::cast<std::vector<int64_t >>(aval_.attr (" shape" ));
596598}
597599
598- PyArray PyArrayResultHandler::Call (absl::Span<const PyArray> py_arrays) const {
600+ nanobind::object PyArrayResultHandler::Call (
601+ absl::Span<const PyArray> py_arrays) const {
599602 auto py_device_list = GetPyDeviceList (sharding_);
600603 if (!py_device_list.ok ()) {
601604 throw nb::value_error (
@@ -610,15 +613,20 @@ PyArray PyArrayResultHandler::Call(absl::Span<const PyArray> py_arrays) const {
610613 xla::Future<>());
611614}
612615
613- PyArray PyArrayResultHandler::Call (nb_class_ptr<PyClient> py_client,
614- ifrt::ArrayRef ifrt_array,
615- xla::Future<> result_status) const {
616- return PyArray (aval_, weak_type_, dtype_, shape_, sharding_,
617- std::move (py_client), std::move (ifrt_array), committed_,
618- skip_checks_, std::move (result_status));
616+ nanobind::object PyArrayResultHandler::Call (nb_class_ptr<PyClient> py_client,
617+ ifrt::ArrayRef ifrt_array,
618+ xla::Future<> result_status) const {
619+ nanobind::object result =
620+ PyArray (aval_, weak_type_, dtype_, shape_, sharding_,
621+ std::move (py_client), std::move (ifrt_array), committed_,
622+ skip_checks_, std::move (result_status));
623+ for (auto & cb : wrappers_) {
624+ result = cb (std::move (result));
625+ }
626+ return result;
619627}
620628
621- PyArray PyArrayResultHandler::Call (PyArray py_array) const {
629+ nanobind::object PyArrayResultHandler::Call (PyArray py_array) const {
622630 return Call (py_array.py_client (), tsl::FormRef (py_array.ifrt_array ()),
623631 xla::Future<>());
624632}
@@ -2364,7 +2372,14 @@ absl::Status PyArray::Register(nb::module_& m) {
23642372 .c_str ());
23652373 },
23662374 nb::sig (
2367- " def __call__(self, arg: Array | Sequence[Array], /) -> Array" ));
2375+ " def __call__(self, arg: Array | Sequence[Array], /) -> Array" ))
2376+ .def (" wrap" , [](const PyArrayResultHandler& self, nb::callable wrapper) {
2377+ auto wrappers = self.wrappers ();
2378+ wrappers.push_back (std::move (wrapper));
2379+ return make_nb_class<PyArrayResultHandler>(
2380+ self.aval (), self.sharding (), self.committed (), self.skip_checks (),
2381+ std::move (wrappers));
2382+ });
23682383
23692384 return absl::OkStatus ();
23702385}
0 commit comments