@@ -392,13 +392,17 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
392392 consts , carry , xs_ = split_list (args , [num_consts , num_carry ])
393393 _ , y_avals = split_list (jaxpr .out_avals , [num_carry ])
394394 num_trips , remainder = divmod (length , unroll )
395- if remainder :
396- if not reverse :
397- xs_ , xs_rem = unzip2 (_map (partial (_split_leading , num_trips * unroll ), xs_ ))
398- else :
399- xs_rem , xs_ = unzip2 (_map (partial (_split_leading , remainder ), xs_ ))
400- xss = [lax .reshape (x , (num_trips , unroll , * x .shape [1 :])) for x in xs_ ]
401- yss = _map (partial (_empty_array , (num_trips , unroll )), y_avals )
395+ if unroll == 1 :
396+ xss = xs_
397+ yss = _map (partial (_empty_array , (length ,)), y_avals )
398+ else :
399+ if remainder :
400+ if not reverse :
401+ xs_ , xs_rem = unzip2 (_map (partial (_split_leading , num_trips * unroll ), xs_ ))
402+ else :
403+ xs_rem , xs_ = unzip2 (_map (partial (_split_leading , remainder ), xs_ ))
404+ xss = [lax .reshape (x , (num_trips , unroll , * x .shape [1 :])) for x in xs_ ]
405+ yss = _map (partial (_empty_array , (num_trips , unroll )), y_avals )
402406
403407 def cond_fun (while_carry ):
404408 i , _ , _ = while_carry
@@ -413,6 +417,9 @@ def body_fun(while_carry):
413417 return i_ + 1 , carry , yss
414418 def inner (n , carry , xs ):
415419 ys = []
420+ if unroll == 1 :
421+ carry_y = eval_jaxpr_p .bind (* consts , * carry , * xs , jaxpr = jaxpr )
422+ return split_list (carry_y , [num_carry ])
416423 for i_ in range (n ):
417424 i = n - i_ - 1 if reverse else i_
418425 x = [slicing .index_in_dim (x , i , keepdims = False ) for x in xs ]
@@ -425,7 +432,10 @@ def inner(n, carry, xs):
425432 if num_trips :
426433 i = lax ._const (num_trips , 0 )
427434 _ , carry , yss = jax .lax .while_loop (cond_fun , body_fun , (i , carry , yss ))
428- ys = [lax .reshape (ys , (num_trips * unroll , * ys .shape [2 :])) for ys in yss ]
435+ if unroll != 1 :
436+ ys = [lax .reshape (ys , (num_trips * unroll , * ys .shape [2 :])) for ys in yss ]
437+ else :
438+ ys = yss
429439 if remainder :
430440 carry , ys_rem = inner (remainder , carry , xs_rem )
431441 ys = _map (_concat , ys , ys_rem ) if not reverse else _map (_concat , ys_rem , ys )
0 commit comments