@@ -84,7 +84,7 @@ def _stack(arrs: Sequence[Array], axis: int=0) -> Array:
8484
8585def _promote_weak_typed_input (
8686 in_val :Any , in_aval :AbstractValue , out_aval :AbstractValue
87- ) -> tuple [AbstractValue , bool ]:
87+ ) -> tuple [Any , bool ]:
8888 if getattr (in_aval , 'weak_type' , False ) and not core .typematch (in_aval , out_aval ):
8989 new_dtype = dtypes .result_type (in_val , out_aval )
9090 return lax .convert_element_type (in_val , new_dtype ), True
@@ -228,7 +228,7 @@ def scan(f, init, xs, length=None):
228228 return carry , stacked_y
229229
230230 if config .mutable_array_checks .value :
231- check_no_aliased_ref_args (lambda : dbg_body , list (args ), list (args_avals ))
231+ check_no_aliased_ref_args (lambda : dbg_body , list (args_avals ), list (args ))
232232
233233 x_avals = xs_avals .map (lambda aval : core .mapped_aval (length , 0 , aval ))
234234 def _create_jaxpr (carry_avals ):
@@ -252,6 +252,8 @@ def _create_jaxpr(carry_avals):
252252 if config .mutable_array_checks .value :
253253 _check_no_aliased_closed_over_refs (dbg_body , consts , list (args ))
254254 carry_out_avals , ys_avals = out_avals .unpack ()
255+ if len (carry_out_avals ) != len (init_avals ):
256+ _check_carry_type ('scan body' , f , init_avals , carry_out_avals )
255257 init , changed = init .map3 (
256258 _promote_weak_typed_input ,
257259 init_avals , carry_out_avals ).unzip2 ()
@@ -277,7 +279,7 @@ def _create_jaxpr(carry_avals):
277279 if unroll < 0 :
278280 raise ValueError ("`unroll` must be a `bool` or a non-negative `int`." )
279281
280- args_flat = ( * init .vals , * xs .vals )
282+ args_flat = [ * init .vals , * xs .vals ]
281283
282284 # If the body forwards an input carry to an output carry, that input is
283285 # read-only and can be moved to be a const. Doing so can lead to efficiency
@@ -381,18 +383,18 @@ def _check_carry_type(name, body_fun, in_carry, out_carry):
381383 if p else 'the input carry' )
382384 if in_carry .tree != out_carry .tree :
383385 try :
384- out_carry = out_carry .unflatten ()
386+ out_carry_unflat = out_carry .unflatten ()
385387 except :
386- out_carry = None
388+ out_carry_unflat = None
387389
388- if out_carry is None :
390+ if out_carry_unflat is None :
389391 differences = (f'the input tree structure is:\n { in_carry .tree } \n ' +
390392 f'the output tree structure is:\n { out_carry .tree } \n ' )
391393 else :
392394 diffs = [f'{ component (path )} is a { thing1 } but the corresponding component '
393395 f'of the carry output is a { thing2 } , so { explanation } '
394396 for path , thing1 , thing2 , explanation
395- in equality_errors (in_carry , out_carry )]
397+ in equality_errors (in_carry . unflatten () , out_carry . unflatten () )]
396398 if len (diffs ) == 0 :
397399 return # the trees may have different aux data, but structures are same
398400 elif len (diffs ) == 1 :
@@ -1709,7 +1711,7 @@ def _create_jaxpr(init_avals):
17091711
17101712 cond_dbg = api_util .debug_info ("while_cond" , cond_fun , (init_val ,), {})
17111713 body_dbg = api_util .debug_info ("while_body" , body_fun , (init_val ,), {})
1712- init_val = FlatTree .flatten (init_val )
1714+ init_val = FlatTree .flatten (init_val ) # type: ignore
17131715 init_aval = init_val .map (core .get_aval )
17141716
17151717 # The body input and output avals must match exactly. However, we want to account for
@@ -1718,6 +1720,10 @@ def _create_jaxpr(init_avals):
17181720 # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
17191721 # necessary, a second time with modified init values.
17201722 cond_jaxpr , body_jaxpr , body_out_avals = _create_jaxpr (init_aval )
1723+ if len (body_out_avals ) != len (init_aval ):
1724+ _check_carry_type ('while_loop body' , body_fun , init_aval , body_out_avals )
1725+ assert False , "shouldn't get here"
1726+
17211727 init_val , changed = init_val .map3 (
17221728 _promote_weak_typed_input ,
17231729 init_aval , body_out_avals ).unzip2 ()
@@ -1749,7 +1755,7 @@ def _create_jaxpr(init_avals):
17491755 _ , keep_cond_carry = split_list (keep_cond , [len (cond_consts )])
17501756 move_to_const = _map (operator .not_ , keep_cond_carry )
17511757
1752- init_vals = list (init_val )
1758+ init_vals = list (init_val ) # type: ignore
17531759 if any (move_to_const ):
17541760 cond_jaxpr = pe .close_jaxpr (cond_jaxpr_ )
17551761 body_jaxpr = pe .prune_closed_jaxpr_outputs (
0 commit comments