Skip to content

Commit 45eee87

Browse files
committed
undo unneeded change
1 parent f965244 commit 45eee87

File tree

1 file changed

+0
-63
lines changed

1 file changed

+0
-63
lines changed

exir/tests/control_flow_models.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -103,66 +103,3 @@ def get_upper_bound_inputs(self):
103103

104104
def get_random_inputs(self):
105105
return torch.rand(2, 4), torch.rand(4)
106-
107-
108-
class FTScanBasic(Module):
109-
"""Basic scan model that computes cumulative sum."""
110-
111-
def __init__(self):
112-
super().__init__()
113-
114-
def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
115-
def combine_fn(carry, x):
116-
new_carry = carry + x
117-
y = new_carry.clone()
118-
return new_carry, y
119-
120-
init = torch.zeros_like(xs[0])
121-
return torch.scan(combine_fn, init, xs)
122-
123-
def get_random_inputs(self):
124-
return (torch.arange(5).float().unsqueeze(1).expand(5, 3),)
125-
126-
127-
class FTScanMultipleCarry(Module):
128-
"""Scan model with multiple carry values (sum and product)."""
129-
130-
def __init__(self):
131-
super().__init__()
132-
133-
def forward(
134-
self, xs: torch.Tensor
135-
) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
136-
def combine_fn(carry, x):
137-
sum_carry, prod_carry = carry
138-
new_sum = sum_carry + x
139-
new_prod = prod_carry * x
140-
y = new_sum + new_prod
141-
return (new_sum, new_prod), y.clone()
142-
143-
init_sum = torch.zeros_like(xs[0])
144-
init_prod = torch.ones_like(xs[0])
145-
return torch.scan(combine_fn, (init_sum, init_prod), xs)
146-
147-
def get_random_inputs(self):
148-
return (torch.arange(1, 5).float().unsqueeze(1).expand(4, 2),)
149-
150-
151-
class FTScanWithAdditionalInputs(Module):
152-
"""Scan model with additional inputs (closure-like behavior)."""
153-
154-
def __init__(self):
155-
super().__init__()
156-
157-
def forward(
158-
self, xs: torch.Tensor, scale: torch.Tensor
159-
) -> tuple[torch.Tensor, torch.Tensor]:
160-
def combine_fn(carry, x):
161-
new_carry = carry + x * scale
162-
return new_carry, new_carry.clone()
163-
164-
init = torch.zeros_like(xs[0])
165-
return torch.scan(combine_fn, init, xs)
166-
167-
def get_random_inputs(self):
168-
return (torch.arange(5).float().unsqueeze(1).expand(5, 3), torch.tensor([2.0]))

0 commit comments

Comments
 (0)