
Commit d563028

test
1 parent d1473c7 commit d563028

File tree: 1 file changed (+110, −0 lines)

exir/emit/test/test_emit.py

Lines changed: 110 additions & 0 deletions
@@ -967,6 +967,116 @@ def combine_fn(carry, x):
         self.assertIn("aten::select_copy", op_names)
         self.assertIn("executorch_prim::et_copy_index", op_names)
 
+    def test_emit_scan_gru(self) -> None:
+        """Test scan with a simple GRU-like computation."""
+        from torch._higher_order_ops.scan import scan
+
+        class SimpleGRU(torch.nn.Module):
+            """Simple single-layer unidirectional GRU using scan."""
+
+            def __init__(self, input_size: int, hidden_size: int):
+                super().__init__()
+                self.input_size = input_size
+                self.hidden_size = hidden_size
+
+                # GRU gates: reset, update, new
+                self.weight_ih = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size, input_size), requires_grad=False
+                )
+                self.weight_hh = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size, hidden_size), requires_grad=False
+                )
+                self.bias_ih = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size), requires_grad=False
+                )
+                self.bias_hh = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size), requires_grad=False
+                )
+
+            def forward(
+                self, x: torch.Tensor, h0: torch.Tensor
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                """
+                Args:
+                    x: Input tensor of shape [seq_len, batch, input_size]
+                    h0: Initial hidden state of shape [batch, hidden_size]
+                Returns:
+                    output: Output tensor of shape [seq_len, batch, hidden_size]
+                    h_n: Final hidden state of shape [batch, hidden_size]
+                """
+                weight_ih = self.weight_ih
+                weight_hh = self.weight_hh
+                bias_ih = self.bias_ih
+                bias_hh = self.bias_hh
+
+                def gru_cell(
+                    h: torch.Tensor, x_t: torch.Tensor
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+                    # Compute gates
+                    gates_ih = torch.nn.functional.linear(x_t, weight_ih, bias_ih)
+                    gates_hh = torch.nn.functional.linear(h, weight_hh, bias_hh)
+
+                    # Split into reset, update, new gates
+                    r_ih, z_ih, n_ih = gates_ih.chunk(3, dim=-1)
+                    r_hh, z_hh, n_hh = gates_hh.chunk(3, dim=-1)
+
+                    r = torch.sigmoid(r_ih + r_hh)
+                    z = torch.sigmoid(z_ih + z_hh)
+                    n = torch.tanh(n_ih + r * n_hh)
+
+                    h_new = (1 - z) * n + z * h
+                    return h_new, h_new.clone()
+
+                final_h, outputs = scan(gru_cell, h0, x)
+                return outputs, final_h
+
+        # Create model and inputs
+        input_size = 4
+        hidden_size = 8
+        seq_len = 5
+        batch_size = 2
+
+        model = SimpleGRU(input_size, hidden_size)
+        x = torch.randn(seq_len, batch_size, input_size)
+        h0 = torch.randn(batch_size, hidden_size)
+        inputs = (x, h0)
+
+        # Run through eager PyTorch
+        eager_outputs = model(*inputs)
+
+        # Export and convert to edge
+        module = to_edge(
+            export(model, inputs, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        et = module.to_executorch()
+        program = et.executorch_program
+
+        # Verify the program has expected operators
+        op_names = [op.name for op in program.execution_plan[0].operators]
+
+        # Should have scan control flow operators
+        self.assertIn("aten::sym_size", op_names)
+        self.assertIn("aten::select_copy", op_names)
+        self.assertIn("executorch_prim::et_copy_index", op_names)
+
+        # Verify we can load the program
+        buffer = et.buffer
+        loaded_model = _load_for_executorch_from_buffer(buffer)
+
+        # Run through executorch
+        et_outputs = loaded_model(inputs)
+
+        # Compare outputs (with tolerance for floating point)
+        self.assertTrue(
+            torch.allclose(et_outputs[0], eager_outputs[0], atol=1e-5),
+            f"Output mismatch: {et_outputs[0]} vs {eager_outputs[0]}",
+        )
+        self.assertTrue(
+            torch.allclose(et_outputs[1], eager_outputs[1], atol=1e-5),
+            f"Final hidden state mismatch: {et_outputs[1]} vs {eager_outputs[1]}",
+        )
+
     def test_dim_order(self) -> None:
         class SimpleLinear(torch.nn.Module):
             def __init__(self) -> None:
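Note on the `scan` higher-order op exercised above: `scan(combine_fn, init, xs)` threads a carry through the leading dimension of `xs` and stacks the per-step outputs. A minimal sketch of those semantics as a plain Python loop (this is not ExecuTorch's actual lowering, and `scan_reference` is a name invented here for illustration):

    from typing import Callable, Tuple

    import torch

    def scan_reference(
        combine_fn: Callable[
            [torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]
        ],
        init: torch.Tensor,
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Thread the carry through xs along dim 0, collecting each step's output.
        carry = init
        ys = []
        for t in range(xs.shape[0]):
            carry, y = combine_fn(carry, xs[t])
            ys.append(y)
        # Per-step outputs are stacked along a new leading dimension.
        return carry, torch.stack(ys)

Applied to the test's `gru_cell`, `h0`, and `x`, this loop yields the same `(final_h, outputs)` pair, which is why `outputs` has shape [seq_len, batch, hidden_size] while `final_h` keeps the carry's shape [batch, hidden_size].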

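As a cross-check, the arithmetic in `gru_cell` matches the textbook single-layer GRU update (written here for reference; $\sigma$ is the logistic sigmoid and $\odot$ is elementwise multiplication, corresponding to the `chunk(3, dim=-1)` split into reset, update, and new gates):

\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\bigl(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\bigr) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}

Here the $W_{i\ast}$ blocks are the three row-chunks of `weight_ih` and the $W_{h\ast}$ blocks are the chunks of `weight_hh`, so each gate sums an input projection and a hidden projection, exactly as `r_ih + r_hh` and friends do in the code.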