Skip to content

Commit 40aad67

Browse files
Matt-HurdGoogle-ML-Automation
authored andcommitted
[jax.collect_profile] Allow arbitrary options to be passed to XProf
XProf has a number of underlying profiler options that can be exposed, with more being added. This will limit the number of updates to this script, as well as prevent any incorrect flags from being listed as a result of version differences. PiperOrigin-RevId: 814381310
1 parent 94ae97f commit 40aad67

File tree

1 file changed

+71
-21
lines changed

1 file changed

+71
-21
lines changed

jax/collect_profile.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
1718
import argparse
1819
import gzip
1920
import os
@@ -39,10 +40,16 @@
3940
for a provided duration. The trace file will be dumped into a directory
4041
(determined by `--log_dir`) and by default, a Perfetto UI link will be generated
4142
to view the resulting trace.
43+
44+
Common tracer options (with defaults):
45+
--host_tracer_level=2 Profiler host tracer level.
46+
--device_tracer_level=1 Profiler device tracer level.
47+
--python_tracer_level=1 Profiler Python tracer level.
4248
"""
4349
_GRPC_PREFIX = 'grpc://'
4450
DEFAULT_NUM_TRACING_ATTEMPTS = 3
45-
parser = argparse.ArgumentParser(description=_DESCRIPTION)
51+
parser = argparse.ArgumentParser(description=_DESCRIPTION,
52+
formatter_class=argparse.RawTextHelpFormatter)
4653
parser.add_argument("--log_dir", default=None,
4754
help=("Directory to store log files. "
4855
"Uses a temporary directory if none provided."),
@@ -56,22 +63,22 @@
5663
parser.add_argument("--host", default="127.0.0.1",
5764
help="Host to collect trace. Defaults to 127.0.0.1",
5865
type=str)
59-
parser.add_argument("--host_tracer_level", default=2,
60-
help="Profiler host tracer level", type=int)
61-
parser.add_argument("--device_tracer_level", default=1,
62-
help="Profiler device tracer level", type=int)
63-
parser.add_argument("--python_tracer_level", default=1,
64-
help="Profiler Python tracer level", type=int)
65-
66-
def collect_profile(port: int, duration_in_ms: int, host: str,
67-
log_dir: os.PathLike | str | None, host_tracer_level: int,
68-
device_tracer_level: int, python_tracer_level: int,
69-
no_perfetto_link: bool):
70-
options = {
71-
"host_tracer_level": host_tracer_level,
72-
"device_tracer_level": device_tracer_level,
73-
"python_tracer_level": python_tracer_level,
66+
67+
def collect_profile(
68+
port: int,
69+
duration_in_ms: int,
70+
host: str,
71+
log_dir: os.PathLike | str | None,
72+
no_perfetto_link: bool,
73+
xprof_options: dict[str, Any] | None = None,):
74+
options: dict[str, Any] = {
75+
"host_tracer_level": 2,
76+
"device_tracer_level": 1,
77+
"python_tracer_level": 1,
7478
}
79+
if xprof_options:
80+
options.update(xprof_options)
81+
7582
IS_GCS_PATH = str(log_dir).startswith("gs://")
7683
log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp())
7784
str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_)
@@ -116,10 +123,53 @@ def _strip_prefix(s, prefix):
116123
def _strip_addresses(addresses, prefix):
117124
return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')])
118125

119-
def main(args):
120-
collect_profile(args.port, args.duration_in_ms, args.host, args.log_dir,
121-
args.host_tracer_level, args.device_tracer_level,
122-
args.python_tracer_level, args.no_perfetto_link)
126+
def _parse_xprof_flags(unknown_flags: list[str]) -> dict[str, Any]:
127+
parsed: dict[str, Any] = {}
128+
i = 0
129+
while i < len(unknown_flags):
130+
arg = unknown_flags[i]
131+
if not arg.startswith('--'):
132+
raise ValueError(f"Unknown positional argument encountered: {arg}")
133+
134+
key = arg[2:]
135+
if "=" in key:
136+
key, value_str = key.split("=", 1)
137+
i += 1
138+
elif i + 1 < len(unknown_flags) and not unknown_flags[i + 1].startswith('--'):
139+
value_str = unknown_flags[i + 1]
140+
i += 2
141+
else:
142+
parsed[key] = True
143+
i += 1
144+
continue
145+
146+
value_lower = value_str.lower()
147+
if value_lower in {'true', 't', 'yes', 'y'}:
148+
parsed[key] = True
149+
elif value_lower in {'false', 'f', 'no', 'n'}:
150+
parsed[key] = False
151+
else:
152+
try:
153+
parsed[key] = int(value_str, 0)
154+
except ValueError:
155+
try:
156+
parsed[key] = float(value_str)
157+
except ValueError:
158+
parsed[key] = value_str # Keep as string
159+
return parsed
160+
161+
162+
def main(known_args, unknown_flags):
163+
xprof_options = _parse_xprof_flags(unknown_flags)
164+
collect_profile(
165+
known_args.port,
166+
known_args.duration_in_ms,
167+
known_args.host,
168+
known_args.log_dir,
169+
known_args.no_perfetto_link,
170+
xprof_options,
171+
)
123172

124173
if __name__ == "__main__":
125-
main(parser.parse_args())
174+
known_args, unknown_flags = parser.parse_known_args()
175+
main(known_args, unknown_flags)

0 commit comments

Comments
 (0)