@@ -103,66 +103,3 @@ def get_upper_bound_inputs(self):
103103
104104 def get_random_inputs (self ):
105105 return torch .rand (2 , 4 ), torch .rand (4 )
106-
107-
108- class FTScanBasic (Module ):
109- """Basic scan model that computes cumulative sum."""
110-
111- def __init__ (self ):
112- super ().__init__ ()
113-
114- def forward (self , xs : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor ]:
115- def combine_fn (carry , x ):
116- new_carry = carry + x
117- y = new_carry .clone ()
118- return new_carry , y
119-
120- init = torch .zeros_like (xs [0 ])
121- return torch .scan (combine_fn , init , xs )
122-
123- def get_random_inputs (self ):
124- return (torch .arange (5 ).float ().unsqueeze (1 ).expand (5 , 3 ),)
125-
126-
127- class FTScanMultipleCarry (Module ):
128- """Scan model with multiple carry values (sum and product)."""
129-
130- def __init__ (self ):
131- super ().__init__ ()
132-
133- def forward (
134- self , xs : torch .Tensor
135- ) -> tuple [tuple [torch .Tensor , torch .Tensor ], torch .Tensor ]:
136- def combine_fn (carry , x ):
137- sum_carry , prod_carry = carry
138- new_sum = sum_carry + x
139- new_prod = prod_carry * x
140- y = new_sum + new_prod
141- return (new_sum , new_prod ), y .clone ()
142-
143- init_sum = torch .zeros_like (xs [0 ])
144- init_prod = torch .ones_like (xs [0 ])
145- return torch .scan (combine_fn , (init_sum , init_prod ), xs )
146-
147- def get_random_inputs (self ):
148- return (torch .arange (1 , 5 ).float ().unsqueeze (1 ).expand (4 , 2 ),)
149-
150-
151- class FTScanWithAdditionalInputs (Module ):
152- """Scan model with additional inputs (closure-like behavior)."""
153-
154- def __init__ (self ):
155- super ().__init__ ()
156-
157- def forward (
158- self , xs : torch .Tensor , scale : torch .Tensor
159- ) -> tuple [torch .Tensor , torch .Tensor ]:
160- def combine_fn (carry , x ):
161- new_carry = carry + x * scale
162- return new_carry , new_carry .clone ()
163-
164- init = torch .zeros_like (xs [0 ])
165- return torch .scan (combine_fn , init , xs )
166-
167- def get_random_inputs (self ):
168- return (torch .arange (5 ).float ().unsqueeze (1 ).expand (5 , 3 ), torch .tensor ([2.0 ]))
0 commit comments