Skip to content

Commit d9a7388

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Refactor IsTransferable.holds to use pattern matching.
PiperOrigin-RevId: 842626789
1 parent 4952b21 commit d9a7388

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

jax/experimental/mosaic/gpu/constraints.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,25 +389,29 @@ def holds(self) -> bool | None:
389389
390390
Returns `None` if the constraint can't be checked.
391391
"""
392-
source = self.source
393-
target = self.target
394392

395-
if isinstance(source, TMEMLayout) and isinstance(target, RegisterLayout):
396-
return self._is_valid_tmem_transfer(source.value, target.value)
397-
if isinstance(target, TMEMLayout) and isinstance(source, RegisterLayout):
398-
return self._is_valid_tmem_transfer(target.value, source.value)
399-
if isinstance(source, TMEMLayout) and isinstance(target, TMEMLayout):
400-
return source == target
401-
if isinstance(source, SMEMTiling) and isinstance(target, RegisterLayout):
402-
return self._is_valid_smem_transfer(source.value, target.value)
403-
if isinstance(target, SMEMTiling) and isinstance(source, RegisterLayout):
404-
return self._is_valid_smem_transfer(target.value, source.value)
405-
if isinstance(target, Constant) and isinstance(source, Constant):
406-
source_type = type(source).__name__
407-
target_type = type(target).__name__
408-
raise NotImplementedError(f"Unsupported transfer: {source_type} -> {target_type}")
393+
assert self.source != self.target, (
394+
"IsTransferable constraints within the same memory space are not"
395+
" supported."
396+
)
409397

410-
return None
398+
match self.source, self.target:
399+
case TMEMLayout(value=src), RegisterLayout(value=dst):
400+
return self._is_valid_tmem_transfer(src, dst)
401+
case RegisterLayout(value=src), TMEMLayout(value=dst):
402+
return self._is_valid_tmem_transfer(dst, src)
403+
case SMEMTiling(value=src), RegisterLayout(value=dst):
404+
return self._is_valid_smem_transfer(src, dst)
405+
case RegisterLayout(value=src), SMEMTiling(value=dst):
406+
return self._is_valid_smem_transfer(dst, src)
407+
case Constant(), Constant():
408+
source_type = type(self.source).__name__
409+
target_type = type(self.target).__name__
410+
raise NotImplementedError(
411+
f"Unsupported transfer: {source_type} -> {target_type}"
412+
)
413+
case _:
414+
return None
411415

412416
def __str__(self):
413417
return f"IsTransferable({self.source}{self.target})"

0 commit comments

Comments
 (0)