Skip to content

Commit 14f072b

Browse files
Use ifrt::AttributeMap::Get instead of directly accessing map
Introduces a variant of Get in AttributeMap that returns the value variant as is. PiperOrigin-RevId: 842283537
1 parent 50bdc72 commit 14f072b

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

jaxlib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ nanobind_pywrap_extension(
407407
"@xla//xla/python:types",
408408
"@xla//xla/python:version",
409409
"@xla//xla/python/ifrt",
410+
"@xla//xla/python/ifrt:attribute_map",
410411
"@xla//xla/python/ifrt:plugin_program",
411412
"@xla//xla/python/ifrt:plugin_program_serdes",
412413
"@xla//xla/python/pjrt_ifrt",

jaxlib/jax.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ limitations under the License.
6565
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
6666
#include "xla/pjrt/status_casters.h"
6767
#include "xla/python/ifrt/array.h"
68+
#include "xla/python/ifrt/attribute_map.h"
6869
#include "xla/python/ifrt/device.h"
6970
#include "xla/python/ifrt/device_list.h"
7071
#include "xla/python/ifrt/executable.h"
@@ -911,11 +912,12 @@ NB_MODULE(_jax, m) {
911912
.def("__getattr__",
912913
[](xla::ifrt::Topology& topology,
913914
std::string_view name) -> nb::object {
914-
const auto& attrs = topology.Attributes().map();
915-
auto it = attrs.find(name);
916-
if (it != attrs.end()) {
915+
auto value =
916+
topology.Attributes().Get<xla::ifrt::AttributeMap::Value>(
917+
std::string(name));
918+
if (value.ok()) {
917919
return std::visit([](auto&& v) { return nb::cast(v.value); },
918-
it->second);
920+
*value);
919921
}
920922
throw nb::attribute_error(
921923
absl::StrCat("Unknown attribute ", name).c_str());

jaxlib/py_client.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ limitations under the License.
7474
#include "xla/pjrt/pjrt_layout.h"
7575
#include "xla/pjrt/status_casters.h"
7676
#include "xla/python/ifrt/array.h"
77+
#include "xla/python/ifrt/attribute_map.h"
7778
#include "xla/python/ifrt/client.h"
7879
#include "xla/python/ifrt/compiler.h"
7980
#include "xla/python/ifrt/device.h"
@@ -1031,11 +1032,12 @@ PyType_Slot PyClient::slots_[] = {
10311032
nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device"))
10321033
.def("__getattr__",
10331034
[](PyClient& client, std::string_view name) -> nb::object {
1034-
const auto& attrs = client.Attributes().map();
1035-
auto it = attrs.find(name);
1036-
if (it != attrs.end()) {
1035+
auto value =
1036+
client.Attributes().Get<xla::ifrt::AttributeMap::Value>(
1037+
std::string(name));
1038+
if (value.ok()) {
10371039
return std::visit([](auto&& v) { return nb::cast(v.value); },
1038-
it->second);
1040+
*value);
10391041
}
10401042
throw nb::attribute_error(
10411043
absl::StrCat("Unknown attribute ", name).c_str());

jaxlib/py_device.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ limitations under the License.
4040
#include "jaxlib/py_memory_space.h"
4141
#include "jaxlib/python_ref_manager.h"
4242
#include "xla/pjrt/status_casters.h"
43+
#include "xla/python/ifrt/attribute_map.h"
4344
#include "xla/python/ifrt/device.h"
4445
#include "xla/python/nb_helpers.h"
4546
#include "xla/python/pjrt_ifrt/pjrt_client.h"
@@ -278,12 +279,13 @@ PyType_Slot PyDevice::slots_[] = {
278279
}
279280
try {
280281
auto device = nb::cast<PyDevice*>(nb::handle(self));
281-
auto name = nb::cast<std::string_view>(nb::handle(key));
282-
const auto& attrs = device->device_->Attributes().map();
283-
auto it = attrs.find(name);
284-
if (it != attrs.end()) {
285-
auto result = std::visit([](auto&& v) { return nb::cast(v.value); },
286-
it->second);
282+
auto name = nb::cast<std::string>(nb::handle(key));
283+
auto value =
284+
device->device_->Attributes().Get<xla::ifrt::AttributeMap::Value>(
285+
name);
286+
if (value.ok()) {
287+
auto result =
288+
std::visit([](auto&& v) { return nb::cast(v.value); }, *value);
287289
return result.release().ptr();
288290
}
289291
PyErr_SetNone(PyExc_AttributeError);

0 commit comments

Comments
 (0)