@@ -967,6 +967,116 @@ def combine_fn(carry, x):
967967 self .assertIn ("aten::select_copy" , op_names )
968968 self .assertIn ("executorch_prim::et_copy_index" , op_names )
969969
970+ def test_emit_scan_gru (self ) -> None :
971+ """Test scan with a simple GRU-like computation."""
972+ from torch ._higher_order_ops .scan import scan
973+
974+ class SimpleGRU (torch .nn .Module ):
975+ """Simple single-layer unidirectional GRU using scan."""
976+
977+ def __init__ (self , input_size : int , hidden_size : int ):
978+ super ().__init__ ()
979+ self .input_size = input_size
980+ self .hidden_size = hidden_size
981+
982+ # GRU gates: reset, update, new
983+ self .weight_ih = torch .nn .Parameter (
984+ torch .randn (3 * hidden_size , input_size ), requires_grad = False
985+ )
986+ self .weight_hh = torch .nn .Parameter (
987+ torch .randn (3 * hidden_size , hidden_size ), requires_grad = False
988+ )
989+ self .bias_ih = torch .nn .Parameter (
990+ torch .randn (3 * hidden_size ), requires_grad = False
991+ )
992+ self .bias_hh = torch .nn .Parameter (
993+ torch .randn (3 * hidden_size ), requires_grad = False
994+ )
995+
996+ def forward (
997+ self , x : torch .Tensor , h0 : torch .Tensor
998+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
999+ """
1000+ Args:
1001+ x: Input tensor of shape [seq_len, batch, input_size]
1002+ h0: Initial hidden state of shape [batch, hidden_size]
1003+ Returns:
1004+ output: Output tensor of shape [seq_len, batch, hidden_size]
1005+ h_n: Final hidden state of shape [batch, hidden_size]
1006+ """
1007+ weight_ih = self .weight_ih
1008+ weight_hh = self .weight_hh
1009+ bias_ih = self .bias_ih
1010+ bias_hh = self .bias_hh
1011+
1012+ def gru_cell (
1013+ h : torch .Tensor , x_t : torch .Tensor
1014+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
1015+ # Compute gates
1016+ gates_ih = torch .nn .functional .linear (x_t , weight_ih , bias_ih )
1017+ gates_hh = torch .nn .functional .linear (h , weight_hh , bias_hh )
1018+
1019+ # Split into reset, update, new gates
1020+ r_ih , z_ih , n_ih = gates_ih .chunk (3 , dim = - 1 )
1021+ r_hh , z_hh , n_hh = gates_hh .chunk (3 , dim = - 1 )
1022+
1023+ r = torch .sigmoid (r_ih + r_hh )
1024+ z = torch .sigmoid (z_ih + z_hh )
1025+ n = torch .tanh (n_ih + r * n_hh )
1026+
1027+ h_new = (1 - z ) * n + z * h
1028+ return h_new , h_new .clone ()
1029+
1030+ final_h , outputs = scan (gru_cell , h0 , x )
1031+ return outputs , final_h
1032+
1033+ # Create model and inputs
1034+ input_size = 4
1035+ hidden_size = 8
1036+ seq_len = 5
1037+ batch_size = 2
1038+
1039+ model = SimpleGRU (input_size , hidden_size )
1040+ x = torch .randn (seq_len , batch_size , input_size )
1041+ h0 = torch .randn (batch_size , hidden_size )
1042+ inputs = (x , h0 )
1043+
1044+ # Run through eager PyTorch
1045+ eager_outputs = model (* inputs )
1046+
1047+ # Export and convert to edge
1048+ module = to_edge (
1049+ export (model , inputs , strict = True ),
1050+ compile_config = exir .EdgeCompileConfig (_check_ir_validity = False ),
1051+ )
1052+ et = module .to_executorch ()
1053+ program = et .executorch_program
1054+
1055+ # Verify the program has expected operators
1056+ op_names = [op .name for op in program .execution_plan [0 ].operators ]
1057+
1058+ # Should have scan control flow operators
1059+ self .assertIn ("aten::sym_size" , op_names )
1060+ self .assertIn ("aten::select_copy" , op_names )
1061+ self .assertIn ("executorch_prim::et_copy_index" , op_names )
1062+
1063+ # Verify we can load the program
1064+ buffer = et .buffer
1065+ loaded_model = _load_for_executorch_from_buffer (buffer )
1066+
1067+ # Run through executorch
1068+ et_outputs = loaded_model (inputs )
1069+
1070+ # Compare outputs (with tolerance for floating point)
1071+ self .assertTrue (
1072+ torch .allclose (et_outputs [0 ], eager_outputs [0 ], atol = 1e-5 ),
1073+ f"Output mismatch: { et_outputs [0 ]} vs { eager_outputs [0 ]} " ,
1074+ )
1075+ self .assertTrue (
1076+ torch .allclose (et_outputs [1 ], eager_outputs [1 ], atol = 1e-5 ),
1077+ f"Final hidden state mismatch: { et_outputs [1 ]} vs { eager_outputs [1 ]} " ,
1078+ )
1079+
9701080 def test_dim_order (self ) -> None :
9711081 class SimpleLinear (torch .nn .Module ):
9721082 def __init__ (self ) -> None :
0 commit comments