|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +from typing import Any |
17 | 18 | import argparse |
18 | 19 | import gzip |
19 | 20 | import os |
|
39 | 40 | for a provided duration. The trace file will be dumped into a directory |
40 | 41 | (determined by `--log_dir`) and by default, a Perfetto UI link will be generated |
41 | 42 | to view the resulting trace. |
| 43 | + |
| 44 | +Common tracer options (with defaults): |
| 45 | + --host_tracer_level=2 Profiler host tracer level. |
| 46 | + --device_tracer_level=1 Profiler device tracer level. |
| 47 | + --python_tracer_level=1 Profiler Python tracer level. |
42 | 48 | """ |
43 | 49 | _GRPC_PREFIX = 'grpc://' |
44 | 50 | DEFAULT_NUM_TRACING_ATTEMPTS = 3 |
45 | | -parser = argparse.ArgumentParser(description=_DESCRIPTION) |
| 51 | +parser = argparse.ArgumentParser(description=_DESCRIPTION, |
| 52 | + formatter_class=argparse.RawTextHelpFormatter) |
46 | 53 | parser.add_argument("--log_dir", default=None, |
47 | 54 | help=("Directory to store log files. " |
48 | 55 | "Uses a temporary directory if none provided."), |
|
56 | 63 | parser.add_argument("--host", default="127.0.0.1", |
57 | 64 | help="Host to collect trace. Defaults to 127.0.0.1", |
58 | 65 | type=str) |
59 | | -parser.add_argument("--host_tracer_level", default=2, |
60 | | - help="Profiler host tracer level", type=int) |
61 | | -parser.add_argument("--device_tracer_level", default=1, |
62 | | - help="Profiler device tracer level", type=int) |
63 | | -parser.add_argument("--python_tracer_level", default=1, |
64 | | - help="Profiler Python tracer level", type=int) |
65 | | - |
66 | | -def collect_profile(port: int, duration_in_ms: int, host: str, |
67 | | - log_dir: os.PathLike | str | None, host_tracer_level: int, |
68 | | - device_tracer_level: int, python_tracer_level: int, |
69 | | - no_perfetto_link: bool): |
70 | | - options = { |
71 | | - "host_tracer_level": host_tracer_level, |
72 | | - "device_tracer_level": device_tracer_level, |
73 | | - "python_tracer_level": python_tracer_level, |
| 66 | + |
| 67 | +def collect_profile( |
| 68 | + port: int, |
| 69 | + duration_in_ms: int, |
| 70 | + host: str, |
| 71 | + log_dir: os.PathLike | str | None, |
| 72 | + no_perfetto_link: bool, |
| 73 | + xprof_options: dict[str, Any] | None = None,): |
| 74 | + options: dict[str, Any] = { |
| 75 | + "host_tracer_level": 2, |
| 76 | + "device_tracer_level": 1, |
| 77 | + "python_tracer_level": 1, |
74 | 78 | } |
| 79 | + if xprof_options: |
| 80 | + options.update(xprof_options) |
| 81 | + |
75 | 82 | IS_GCS_PATH = str(log_dir).startswith("gs://") |
76 | 83 | log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp()) |
77 | 84 | str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_) |
@@ -116,10 +123,53 @@ def _strip_prefix(s, prefix): |
116 | 123 | def _strip_addresses(addresses, prefix): |
117 | 124 | return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')]) |
118 | 125 |
|
119 | | -def main(args): |
120 | | - collect_profile(args.port, args.duration_in_ms, args.host, args.log_dir, |
121 | | - args.host_tracer_level, args.device_tracer_level, |
122 | | - args.python_tracer_level, args.no_perfetto_link) |
| 126 | +def _parse_xprof_flags(unknown_flags: list[str]) -> dict[str, Any]: |
| 127 | + parsed: dict[str, Any] = {} |
| 128 | + i = 0 |
| 129 | + while i < len(unknown_flags): |
| 130 | + arg = unknown_flags[i] |
| 131 | + if not arg.startswith('--'): |
| 132 | + raise ValueError(f"Unknown positional argument encountered: {arg}") |
| 133 | + |
| 134 | + key = arg[2:] |
| 135 | + if "=" in key: |
| 136 | + key, value_str = key.split("=", 1) |
| 137 | + i += 1 |
| 138 | + elif i + 1 < len(unknown_flags) and not unknown_flags[i + 1].startswith('--'): |
| 139 | + value_str = unknown_flags[i + 1] |
| 140 | + i += 2 |
| 141 | + else: |
| 142 | + parsed[key] = True |
| 143 | + i += 1 |
| 144 | + continue |
| 145 | + |
| 146 | + value_lower = value_str.lower() |
| 147 | + if value_lower in {'true', 't', 'yes', 'y'}: |
| 148 | + parsed[key] = True |
| 149 | + elif value_lower in {'false', 'f', 'no', 'n'}: |
| 150 | + parsed[key] = False |
| 151 | + else: |
| 152 | + try: |
| 153 | + parsed[key] = int(value_str, 0) |
| 154 | + except ValueError: |
| 155 | + try: |
| 156 | + parsed[key] = float(value_str) |
| 157 | + except ValueError: |
| 158 | + parsed[key] = value_str # Keep as string |
| 159 | + return parsed |
| 160 | + |
| 161 | + |
| 162 | +def main(known_args, unknown_flags): |
| 163 | + xprof_options = _parse_xprof_flags(unknown_flags) |
| 164 | + collect_profile( |
| 165 | + known_args.port, |
| 166 | + known_args.duration_in_ms, |
| 167 | + known_args.host, |
| 168 | + known_args.log_dir, |
| 169 | + known_args.no_perfetto_link, |
| 170 | + xprof_options, |
| 171 | + ) |
123 | 172 |
|
124 | 173 | if __name__ == "__main__": |
125 | | - main(parser.parse_args()) |
| 174 | + known_args, unknown_flags = parser.parse_known_args() |
| 175 | + main(known_args, unknown_flags) |
0 commit comments