@@ -389,25 +389,29 @@ def holds(self) -> bool | None:
389389
390390 Returns `None` if the constraint can't be checked.
391391 """
392- source = self .source
393- target = self .target
394392
395- if isinstance (source , TMEMLayout ) and isinstance (target , RegisterLayout ):
396- return self ._is_valid_tmem_transfer (source .value , target .value )
397- if isinstance (target , TMEMLayout ) and isinstance (source , RegisterLayout ):
398- return self ._is_valid_tmem_transfer (target .value , source .value )
399- if isinstance (source , TMEMLayout ) and isinstance (target , TMEMLayout ):
400- return source == target
401- if isinstance (source , SMEMTiling ) and isinstance (target , RegisterLayout ):
402- return self ._is_valid_smem_transfer (source .value , target .value )
403- if isinstance (target , SMEMTiling ) and isinstance (source , RegisterLayout ):
404- return self ._is_valid_smem_transfer (target .value , source .value )
405- if isinstance (target , Constant ) and isinstance (source , Constant ):
406- source_type = type (source ).__name__
407- target_type = type (target ).__name__
408- raise NotImplementedError (f"Unsupported transfer: { source_type } -> { target_type } " )
393+ assert self .source != self .target , (
394+ "IsTransferable constraints within the same memory space are not"
395+ " supported."
396+ )
409397
410- return None
398+ match self .source , self .target :
399+ case TMEMLayout (value = src ), RegisterLayout (value = dst ):
400+ return self ._is_valid_tmem_transfer (src , dst )
401+ case RegisterLayout (value = src ), TMEMLayout (value = dst ):
402+ return self ._is_valid_tmem_transfer (dst , src )
403+ case SMEMTiling (value = src ), RegisterLayout (value = dst ):
404+ return self ._is_valid_smem_transfer (src , dst )
405+ case RegisterLayout (value = src ), SMEMTiling (value = dst ):
406+ return self ._is_valid_smem_transfer (dst , src )
407+ case Constant (), Constant ():
408+ source_type = type (self .source ).__name__
409+ target_type = type (self .target ).__name__
410+ raise NotImplementedError (
411+ f"Unsupported transfer: { source_type } -> { target_type } "
412+ )
413+ case _:
414+ return None
411415
412416 def __str__ (self ):
413417 return f"IsTransferable({ self .source } ⟶ { self .target } )"
0 commit comments