@@ -886,30 +886,6 @@ def combine_fn(carry, x):
886886 # 8. Verify we have the body operations (add from combine_fn)
887887 self .assertIn ("aten::add" , op_names , "Should have add from combine_fn body" )
888888
889- def test_load_emit_scan (self ) -> None :
890- """Test that scan program can be loaded by the runtime."""
891- from torch ._higher_order_ops .scan import scan
892-
893- class ScanCumSum (torch .nn .Module ):
894- def forward (self , xs : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
895- def combine_fn (carry , x ):
896- new_carry = carry + x
897- return new_carry , new_carry .clone ()
898-
899- init = torch .zeros_like (xs [0 ])
900- return scan (combine_fn , init , xs )
901-
902- f = ScanCumSum ()
903- # Use contiguous tensor to avoid stride=0 issue
904- inputs = (torch .arange (15 ).float ().reshape (5 , 3 ),)
905-
906- module = to_edge (
907- export (f , inputs , strict = True ),
908- compile_config = exir .EdgeCompileConfig (_check_ir_validity = False ),
909- )
910- # This should not raise - verifies the program is loadable
911- _load_for_executorch_from_buffer (module .to_executorch ().buffer )
912-
913889 def test_run_emit_scan_cumsum (self ) -> None :
914890 """Test scan execution correctness: cumulative sum."""
915891 from torch ._higher_order_ops .scan import scan
0 commit comments