Skip to content

Commit 91d68b5

Browse files
zuasia, jax authors
authored and committed
Create a JAX config API to allow custom PJRT client create-option settings. This allows a device platform's PJRT client to be aware of the calling (customer) ML framework.
PiperOrigin-RevId: 638009713
1 parent db11842 commit 91d68b5

File tree

5 files changed

+59
-15
lines changed

5 files changed

+59
-15
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ pytype_strict_library(
389389
name = "cloud_tpu_init",
390390
srcs = ["_src/cloud_tpu_init.py"],
391391
deps = [
392+
":config",
392393
":hardware_utils",
393394
":version",
394395
],

jax/_src/cloud_tpu_init.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414

1515
import os
16-
from jax._src import hardware_utils
1716
from jax import version
17+
from jax._src import config
18+
from jax._src import hardware_utils
1819

1920
running_in_cloud_tpu_vm: bool = False
2021

@@ -73,3 +74,9 @@ def cloud_tpu_init() -> None:
7374
# this makes tensorstore serialization work better on TPU
7475
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
7576
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
77+
78+
if config.jax_pjrt_client_create_options.value is None:
79+
config.update(
80+
'jax_pjrt_client_create_options',
81+
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
82+
)

jax/_src/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,12 @@ def update_thread_local_jit_state(**kw):
935935
'otherwise.'
936936
))
937937

938+
jax_pjrt_client_create_options = define_optional_string_state(
939+
name='jax_pjrt_client_create_options',
940+
default=None,
941+
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
942+
'provided to a device platform pjrt client as extra arguments.'))
943+
938944
enable_checks = define_bool_state(
939945
name='jax_enable_checks',
940946
default=False,

jax/_src/xla_bridge.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from jax._src.lib import cuda_versions
4848
from jax._src.lib import xla_client
4949
from jax._src.lib import xla_extension
50+
from jax._src.lib import xla_extension_version
5051
from jax._src.lib import jaxlib
5152

5253
logger = logging.getLogger(__name__)
@@ -160,7 +161,13 @@ def _log_warning():
160161
t.start()
161162

162163
try:
163-
client = xla_client.make_tpu_client(_get_tpu_library_path())
164+
if xla_extension_version >= 267:
165+
client = xla_client.make_tpu_client( # type: ignore
166+
_get_tpu_library_path(),
167+
_options_from_jax_configs("tpu"))
168+
else:
169+
client = xla_client.make_tpu_client(
170+
_get_tpu_library_path())
164171
finally:
165172
t.cancel()
166173

@@ -618,16 +625,30 @@ def discover_pjrt_plugins() -> None:
618625

619626

620627
def _options_from_jax_configs(plugin_name):
621-
if plugin_name != "cuda":
622-
return {}
623-
624628
options = {}
625-
visible_devices = CUDA_VISIBLE_DEVICES.value
626-
if visible_devices != 'all':
627-
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
628-
options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
629-
if options['enable_mock_nccl']:
630-
options['num_nodes'] = _MOCK_NUM_GPUS.value
629+
630+
pjrt_client_options = config.jax_pjrt_client_create_options.value
631+
pjrt_client_option_list = []
632+
if pjrt_client_options:
633+
pjrt_client_option_list = pjrt_client_options.split(";")
634+
635+
for option in pjrt_client_option_list:
636+
option_list = option.split(":")
637+
if (len(option_list) != 2):
638+
raise RuntimeError(
639+
"Multiple ':' separators for option in "
640+
f"jax_pjrt_client_create_options: '{option}'. "
641+
"Should be in format 'key:value'")
642+
options[option_list[0]] = option_list[1]
643+
644+
if plugin_name == "cuda":
645+
visible_devices = CUDA_VISIBLE_DEVICES.value
646+
if visible_devices != 'all':
647+
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
648+
options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
649+
if options['enable_mock_nccl']:
650+
options['num_nodes'] = _MOCK_NUM_GPUS.value
651+
631652
return options
632653

633654

tests/xla_bridge_test.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from jax._src import xla_bridge as xb
2727
from jax._src.interpreters import xla
2828
from jax._src.lib import xla_client as xc
29+
from jax._src.lib import xla_extension_version
2930

3031
config.parse_flags_with_absl()
3132

@@ -143,7 +144,7 @@ def test_timer_tpu_warning(self):
143144
with warnings.catch_warnings(record=True) as w:
144145
warnings.simplefilter("always")
145146

146-
def _mock_tpu_client(library_path=None):
147+
def _mock_tpu_client_with_options(library_path=None, options=None):
147148
time_to_wait = 5
148149
start = time.time()
149150
while not w:
@@ -157,9 +158,17 @@ def _mock_tpu_client(library_path=None):
157158
msg = str(w[-1].message)
158159
self.assertIn("Did you run your code on all TPU hosts?", msg)
159160

160-
with mock.patch.object(xc, "make_tpu_client",
161-
side_effect=_mock_tpu_client):
162-
xb.tpu_client_timer_callback(0.01)
161+
def _mock_tpu_client(library_path=None):
162+
_mock_tpu_client_with_options(library_path=library_path, options=None)
163+
164+
if xla_extension_version >= 267:
165+
with mock.patch.object(xc, "make_tpu_client",
166+
side_effect=_mock_tpu_client_with_options):
167+
xb.tpu_client_timer_callback(0.01)
168+
else:
169+
with mock.patch.object(xc, "make_tpu_client",
170+
side_effect=_mock_tpu_client):
171+
xb.tpu_client_timer_callback(0.01)
163172

164173
def test_register_plugin(self):
165174
with self.assertLogs(level="WARNING") as log_output:

0 commit comments

Comments
 (0)