Skip to content

Commit 26f8269

Browse files
committed
JIT: Remove redundant branches to jump in the assembly optimizer
* Refactor JIT assembly optimizer making instructions instances not just strings * Remove redundant jumps and branches where legal to do so
1 parent 918a9ac commit 26f8269

File tree

2 files changed

+158
-61
lines changed

2 files changed

+158
-61
lines changed

Tools/jit/_optimizers.py

Lines changed: 157 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Low-level optimization of textual assembly."""
22

33
import dataclasses
4+
import enum
45
import pathlib
56
import re
67
import typing
@@ -65,23 +66,70 @@
6566
# MyPy doesn't understand that a invariant variable can be initialized by a covariant value
6667
CUSTOM_AARCH64_BRANCH19: str | None = "CUSTOM_AARCH64_BRANCH19"
6768

68-
# Branches are either b.{cond} or bc.{cond}
69-
_AARCH64_BRANCHES: dict[str, tuple[str | None, str | None]] = {
70-
"b." + cond: (("b." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
71-
for (cond, inverse) in _AARCH64_COND_CODES.items()
72-
} | {
73-
"bc." + cond: (("bc." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
74-
for (cond, inverse) in _AARCH64_COND_CODES.items()
69+
_AARCH64_SHORT_BRANCHES = {
70+
"cbz": "cbnz",
71+
"cbnz": "cbz",
72+
"tbz": "tbnz",
73+
"tbnz": "tbz",
7574
}
7675

76+
# Branches are either b.{cond} or bc.{cond}
77+
_AARCH64_BRANCHES: dict[str, tuple[str | None, str | None]] = (
78+
{
79+
"b." + cond: (("b." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
80+
for (cond, inverse) in _AARCH64_COND_CODES.items()
81+
}
82+
| {
83+
"bc." + cond: (("bc." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
84+
for (cond, inverse) in _AARCH64_COND_CODES.items()
85+
}
86+
| {cond: (inverse, None) for (cond, inverse) in _AARCH64_SHORT_BRANCHES.items()}
87+
)
88+
89+
90+
@enum.unique
91+
class InstructionKind(enum.Enum):
92+
93+
JUMP = enum.auto()
94+
LONG_BRANCH = enum.auto()
95+
SHORT_BRANCH = enum.auto()
96+
RETURN = enum.auto()
97+
OTHER = enum.auto()
98+
99+
100+
@dataclasses.dataclass
101+
class Instruction:
102+
kind: InstructionKind
103+
name: str
104+
text: str
105+
target: str | None
106+
107+
def is_branch(self) -> bool:
108+
return self.kind in (InstructionKind.LONG_BRANCH, InstructionKind.SHORT_BRANCH)
109+
110+
def update_target(self, target: str) -> "Instruction":
111+
assert self.target is not None
112+
return Instruction(
113+
self.kind, self.name, self.text.replace(self.target, target), target
114+
)
115+
116+
def update_name_and_target(self, name: str, target: str) -> "Instruction":
117+
assert self.target is not None
118+
return Instruction(
119+
self.kind,
120+
name,
121+
self.text.replace(self.name, name).replace(self.target, target),
122+
target,
123+
)
124+
77125

78126
@dataclasses.dataclass
79127
class _Block:
80128
label: str | None = None
81129
# Non-instruction lines like labels, directives, and comments:
82130
noninstructions: list[str] = dataclasses.field(default_factory=list)
83131
# Instruction lines:
84-
instructions: list[str] = dataclasses.field(default_factory=list)
132+
instructions: list[Instruction] = dataclasses.field(default_factory=list)
85133
# If this block ends in a jump, where to?
86134
target: typing.Self | None = None
87135
# The next block in the linked list:
@@ -122,6 +170,9 @@ class Optimizer:
122170
# Override everything that follows in subclasses:
123171
_supports_external_relocations = True
124172
_branches: typing.ClassVar[dict[str, tuple[str | None, str | None]]] = {}
173+
# Short branches are instructions that can branch within a micro-op,
174+
# but might not have the reach to branch anywhere within a trace.
175+
_short_branches: typing.ClassVar[dict[str, str]] = {}
125176
# Two groups (instruction and target):
126177
_re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
127178
# One group (target):
@@ -152,16 +203,19 @@ def __post_init__(self) -> None:
152203
if block.target or not block.fallthrough:
153204
# Current block ends with a branch, jump, or return. New block:
154205
block.link = block = _Block()
155-
block.instructions.append(line)
156-
if match := self._re_branch.match(line):
206+
inst = self._parse_instruction(line)
207+
block.instructions.append(inst)
208+
if inst.is_branch():
157209
# A block ending in a branch has a target and fallthrough:
158-
block.target = self._lookup_label(match["target"])
210+
assert inst.target is not None
211+
block.target = self._lookup_label(inst.target)
159212
assert block.fallthrough
160-
elif match := self._re_jump.match(line):
213+
elif inst.kind == InstructionKind.JUMP:
161214
# A block ending in a jump has a target and no fallthrough:
162-
block.target = self._lookup_label(match["target"])
215+
assert inst.target is not None
216+
block.target = self._lookup_label(inst.target)
163217
block.fallthrough = False
164-
elif self._re_return.match(line):
218+
elif inst.kind == InstructionKind.RETURN:
165219
# A block ending in a return has no target and fallthrough:
166220
assert not block.target
167221
block.fallthrough = False
@@ -174,39 +228,47 @@ def _preprocess(self, text: str) -> str:
174228
continue_label = f"{self.label_prefix}_JIT_CONTINUE"
175229
return re.sub(continue_symbol, continue_label, text)
176230

177-
@classmethod
178-
def _invert_branch(cls, line: str, target: str) -> str | None:
179-
match = cls._re_branch.match(line)
180-
assert match
181-
inverted_reloc = cls._branches.get(match["instruction"])
231+
def _parse_instruction(self, line: str) -> Instruction:
232+
target = None
233+
if match := self._re_branch.match(line):
234+
target = match["target"]
235+
name = match["instruction"]
236+
if name in self._short_branches:
237+
kind = InstructionKind.SHORT_BRANCH
238+
else:
239+
kind = InstructionKind.LONG_BRANCH
240+
elif match := self._re_jump.match(line):
241+
target = match["target"]
242+
name = line[: -len(target)].strip()
243+
kind = InstructionKind.JUMP
244+
elif match := self._re_return.match(line):
245+
name = line
246+
kind = InstructionKind.RETURN
247+
else:
248+
name, *_ = line.split(" ")
249+
kind = InstructionKind.OTHER
250+
return Instruction(kind, name, line, target)
251+
252+
def _invert_branch(self, inst: Instruction, target: str) -> Instruction | None:
253+
assert inst.is_branch()
254+
if inst.kind == InstructionKind.SHORT_BRANCH and self._is_far_target(target):
255+
return None
256+
inverted_reloc = self._branches.get(inst.name)
182257
if inverted_reloc is None:
183258
return None
184259
inverted = inverted_reloc[0]
185260
if not inverted:
186261
return None
187-
(a, b), (c, d) = match.span("instruction"), match.span("target")
188-
# Before:
189-
# je FOO
190-
# After:
191-
# jne BAR
192-
return "".join([line[:a], inverted, line[b:c], target, line[d:]])
193-
194-
@classmethod
195-
def _update_jump(cls, line: str, target: str) -> str:
196-
match = cls._re_jump.match(line)
197-
assert match
198-
a, b = match.span("target")
199-
# Before:
200-
# jmp FOO
201-
# After:
202-
# jmp BAR
203-
return "".join([line[:a], target, line[b:]])
262+
return inst.update_name_and_target(inverted, target)
204263

205264
def _lookup_label(self, label: str) -> _Block:
206265
if label not in self._labels:
207266
self._labels[label] = _Block(label)
208267
return self._labels[label]
209268

269+
def _is_far_target(self, label: str) -> bool:
270+
return not label.startswith(self.label_prefix)
271+
210272
def _blocks(self) -> typing.Generator[_Block, None, None]:
211273
block: _Block | None = self._root
212274
while block:
@@ -222,7 +284,8 @@ def _body(self) -> str:
222284
# Make it easy to tell at a glance where cold code is:
223285
lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#"))
224286
lines.extend(block.noninstructions)
225-
lines.extend(block.instructions)
287+
for inst in block.instructions:
288+
lines.append(inst.text)
226289
return "\n".join(lines)
227290

228291
def _predecessors(self, block: _Block) -> typing.Generator[_Block, None, None]:
@@ -289,8 +352,8 @@ def _invert_hot_branches(self) -> None:
289352
if inverted is None:
290353
continue
291354
branch.instructions[-1] = inverted
292-
jump.instructions[-1] = self._update_jump(
293-
jump.instructions[-1], branch.target.label
355+
jump.instructions[-1] = jump.instructions[-1].update_target(
356+
branch.target.label
294357
)
295358
branch.target, jump.target = jump.target, branch.target
296359
jump.hot = True
@@ -299,49 +362,81 @@ def _remove_redundant_jumps(self) -> None:
299362
# Zero-length jumps can be introduced by _insert_continue_label and
300363
# _invert_hot_branches:
301364
for block in self._blocks():
365+
target = block.target
366+
if target is None:
367+
continue
368+
target = target.resolve()
302369
# Before:
303370
# jmp FOO
304371
# FOO:
305372
# After:
306373
# FOO:
307-
if (
308-
block.target
309-
and block.link
310-
and block.target.resolve() is block.link.resolve()
311-
):
374+
if block.link and target is block.link.resolve():
312375
block.target = None
313376
block.fallthrough = True
314377
block.instructions.pop()
378+
# Before:
379+
# br ? FOO:
380+
# ...
381+
# FOO:
382+
# jump BAR
383+
# After:
384+
# br cond BAR
385+
# ...
386+
elif (
387+
len(target.instructions) == 1
388+
and target.instructions[0].kind == InstructionKind.JUMP
389+
):
390+
assert target.target is not None
391+
assert target.target.label is not None
392+
if block.instructions[
393+
-1
394+
].kind == InstructionKind.SHORT_BRANCH and self._is_far_target(
395+
target.target.label
396+
):
397+
continue
398+
block.target = target.target
399+
block.instructions[-1] = block.instructions[-1].update_target(
400+
target.target.label
401+
)
402+
403+
def _remove_unreachable(self) -> None:
404+
prev: _Block | None = None
405+
for block in self._blocks():
406+
if not list(self._predecessors(block)) and prev:
407+
prev.link = block.link
408+
else:
409+
prev = block
315410

316411
def _fixup_external_labels(self) -> None:
317412
if self._supports_external_relocations:
318413
# Nothing to fix up
319414
return
320-
for block in self._blocks():
415+
for index, block in enumerate(self._blocks()):
321416
if block.target and block.fallthrough:
322417
branch = block.instructions[-1]
323-
match = self._re_branch.match(branch)
324-
assert match is not None
325-
target = match["target"]
326-
reloc = self._branches[match["instruction"]][1]
327-
if reloc is not None and not target.startswith(self.label_prefix):
418+
assert branch.is_branch()
419+
target = branch.target
420+
assert target is not None
421+
reloc = self._branches[branch.name][1]
422+
if reloc is not None and self._is_far_target(target):
328423
name = target[len(self.symbol_prefix) :]
329-
block.instructions[-1] = (
330-
f"// target='{target}' prefix='{self.label_prefix}'"
331-
)
332-
block.instructions.append(
333-
f"{self.symbol_prefix}{reloc}_JIT_RELOCATION_{name}:"
424+
label = f"{self.symbol_prefix}{reloc}_JIT_RELOCATION_{name}_JIT_RELOCATION_{index}:"
425+
block.instructions[-1] = Instruction(
426+
InstructionKind.OTHER, "", label, None
334427
)
335-
a, b = match.span("target")
336-
branch = "".join([branch[:a], "0", branch[b:]])
337-
block.instructions.append(branch)
428+
block.instructions.append(branch.update_target("0"))
338429

339430
def run(self) -> None:
340431
"""Run this optimizer."""
341432
self._insert_continue_label()
342433
self._mark_hot_blocks()
343-
self._invert_hot_branches()
344-
self._remove_redundant_jumps()
434+
# Removing branches can expose opportunities for more branch removal
435+
# Repeat a few times. 2 would probably do, but it's fast enough with 4.
436+
for _ in range(4):
437+
self._invert_hot_branches()
438+
self._remove_redundant_jumps()
439+
self._remove_unreachable()
345440
self._fixup_external_labels()
346441
self.path.write_text(self._body())
347442

@@ -350,6 +445,7 @@ class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods
350445
"""aarch64-pc-windows-msvc/aarch64-apple-darwin/aarch64-unknown-linux-gnu"""
351446

352447
_branches = _AARCH64_BRANCHES
448+
_short_branches = _AARCH64_SHORT_BRANCHES
353449
# Mach-O does not support the 19 bit branch locations needed for branch reordering
354450
_supports_external_relocations = False
355451
_re_branch = re.compile(
@@ -366,6 +462,7 @@ class OptimizerX86(Optimizer): # pylint: disable = too-few-public-methods
366462
"""i686-pc-windows-msvc/x86_64-apple-darwin/x86_64-unknown-linux-gnu"""
367463

368464
_branches = _X86_BRANCHES
465+
_short_branches = {}
369466
_re_branch = re.compile(
370467
rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
371468
)

Tools/jit/_stencils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def convert_labels_to_relocations(self) -> None:
226226
for name, hole_plus in self.symbols.items():
227227
if isinstance(name, str) and "_JIT_RELOCATION_" in name:
228228
_, offset = hole_plus
229-
reloc, target = name.split("_JIT_RELOCATION_")
229+
reloc, target, _ = name.split("_JIT_RELOCATION_")
230230
value, symbol = symbol_to_value(target)
231231
hole = Hole(
232232
int(offset), typing.cast(_schema.HoleKind, reloc), value, symbol, 0

0 commit comments

Comments
 (0)