File tree Expand file tree Collapse file tree 4 files changed +21
-14
lines changed
Expand file tree Collapse file tree 4 files changed +21
-14
lines changed Original file line number Diff line number Diff line change @@ -407,6 +407,7 @@ nanobind_pywrap_extension(
407407 "@xla//xla/python:types" ,
408408 "@xla//xla/python:version" ,
409409 "@xla//xla/python/ifrt" ,
410+ "@xla//xla/python/ifrt:attribute_map" ,
410411 "@xla//xla/python/ifrt:plugin_program" ,
411412 "@xla//xla/python/ifrt:plugin_program_serdes" ,
412413 "@xla//xla/python/pjrt_ifrt" ,
Original file line number Diff line number Diff line change @@ -65,6 +65,7 @@ limitations under the License.
6565#include " xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
6666#include " xla/pjrt/status_casters.h"
6767#include " xla/python/ifrt/array.h"
68+ #include " xla/python/ifrt/attribute_map.h"
6869#include " xla/python/ifrt/device.h"
6970#include " xla/python/ifrt/device_list.h"
7071#include " xla/python/ifrt/executable.h"
@@ -911,11 +912,12 @@ NB_MODULE(_jax, m) {
911912 .def (" __getattr__" ,
912913 [](xla::ifrt::Topology& topology,
913914 std::string_view name) -> nb::object {
914- const auto & attrs = topology.Attributes ().map ();
915- auto it = attrs.find (name);
916- if (it != attrs.end ()) {
915+ auto value =
916+ topology.Attributes ().Get <xla::ifrt::AttributeMap::Value>(
917+ std::string (name));
918+ if (value.ok ()) {
917919 return std::visit ([](auto && v) { return nb::cast (v.value ); },
918- it-> second );
920+ *value );
919921 }
920922 throw nb::attribute_error (
921923 absl::StrCat (" Unknown attribute " , name).c_str ());
Original file line number Diff line number Diff line change @@ -74,6 +74,7 @@ limitations under the License.
7474#include " xla/pjrt/pjrt_layout.h"
7575#include " xla/pjrt/status_casters.h"
7676#include " xla/python/ifrt/array.h"
77+ #include " xla/python/ifrt/attribute_map.h"
7778#include " xla/python/ifrt/client.h"
7879#include " xla/python/ifrt/compiler.h"
7980#include " xla/python/ifrt/device.h"
@@ -1031,11 +1032,12 @@ PyType_Slot PyClient::slots_[] = {
10311032 nb::arg (" dtype" ), nb::arg (" shard_shape" ), nb::arg (" device" ))
10321033 .def (" __getattr__" ,
10331034 [](PyClient& client, std::string_view name) -> nb::object {
1034- const auto & attrs = client.Attributes ().map ();
1035- auto it = attrs.find (name);
1036- if (it != attrs.end ()) {
1035+ auto value =
1036+ client.Attributes ().Get <xla::ifrt::AttributeMap::Value>(
1037+ std::string (name));
1038+ if (value.ok ()) {
10371039 return std::visit ([](auto && v) { return nb::cast (v.value ); },
1038- it-> second );
1040+ *value );
10391041 }
10401042 throw nb::attribute_error (
10411043 absl::StrCat (" Unknown attribute " , name).c_str ());
Original file line number Diff line number Diff line change @@ -40,6 +40,7 @@ limitations under the License.
4040#include " jaxlib/py_memory_space.h"
4141#include " jaxlib/python_ref_manager.h"
4242#include " xla/pjrt/status_casters.h"
43+ #include " xla/python/ifrt/attribute_map.h"
4344#include " xla/python/ifrt/device.h"
4445#include " xla/python/nb_helpers.h"
4546#include " xla/python/pjrt_ifrt/pjrt_client.h"
@@ -278,12 +279,13 @@ PyType_Slot PyDevice::slots_[] = {
278279 }
279280 try {
280281 auto device = nb::cast<PyDevice*>(nb::handle (self));
281- auto name = nb::cast<std::string_view>(nb::handle (key));
282- const auto & attrs = device->device_ ->Attributes ().map ();
283- auto it = attrs.find (name);
284- if (it != attrs.end ()) {
285- auto result = std::visit ([](auto && v) { return nb::cast (v.value ); },
286- it->second );
282+ auto name = nb::cast<std::string>(nb::handle (key));
283+ auto value =
284+ device->device_ ->Attributes ().Get <xla::ifrt::AttributeMap::Value>(
285+ name);
286+ if (value.ok ()) {
287+ auto result =
288+ std::visit ([](auto && v) { return nb::cast (v.value ); }, *value);
287289 return result.release ().ptr ();
288290 }
289291 PyErr_SetNone (PyExc_AttributeError);
You can’t perform that action at this time.
0 commit comments