Skip to content

Commit a33d837

Browse files
committed
Remove type parameters. Type checker doesn't like them
1 parent 75e7243 commit a33d837

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,9 +2292,9 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
22922292
@weakref_lru_cache
22932293
def trace_to_jaxpr(
22942294
fun: Callable,
2295-
in_avals: FlatTree[AbstractValue | core.AvalQDD], # (args, kwargs) pair
2295+
in_avals: FlatTree, # (args, kwargs) pair
22962296
debug_info: core.DebugInfo
2297-
) -> tuple[ClosedJaxpr, PyTreeDef]:
2297+
) -> tuple[ClosedJaxpr, FlatTree]:
22982298
config.enable_checks.value and debug_info.assert_arg_names(len(in_avals))
22992299
parent_trace = core.trace_ctx.trace
23002300
trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace)

jax/_src/lax/control_flow/loops.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _stack(arrs: Sequence[Array], axis: int=0) -> Array:
8484

8585
def _promote_weak_typed_input(
8686
in_val:Any, in_aval:AbstractValue, out_aval:AbstractValue
87-
) -> tuple[AbstractValue, bool]:
87+
) -> tuple[Any, bool]:
8888
if getattr(in_aval, 'weak_type', False) and not core.typematch(in_aval, out_aval):
8989
new_dtype = dtypes.result_type(in_val, out_aval)
9090
return lax.convert_element_type(in_val, new_dtype), True
@@ -228,7 +228,7 @@ def scan(f, init, xs, length=None):
228228
return carry, stacked_y
229229

230230
if config.mutable_array_checks.value:
231-
check_no_aliased_ref_args(lambda: dbg_body, list(args), list(args_avals))
231+
check_no_aliased_ref_args(lambda: dbg_body, list(args_avals), list(args))
232232

233233
x_avals = xs_avals.map(lambda aval: core.mapped_aval(length, 0, aval))
234234
def _create_jaxpr(carry_avals):
@@ -252,6 +252,8 @@ def _create_jaxpr(carry_avals):
252252
if config.mutable_array_checks.value:
253253
_check_no_aliased_closed_over_refs(dbg_body, consts, list(args))
254254
carry_out_avals, ys_avals = out_avals.unpack()
255+
if len(carry_out_avals) != len(init_avals):
256+
_check_carry_type('scan body', f, init_avals, carry_out_avals)
255257
init, changed = init.map3(
256258
_promote_weak_typed_input,
257259
init_avals, carry_out_avals).unzip2()
@@ -277,7 +279,7 @@ def _create_jaxpr(carry_avals):
277279
if unroll < 0:
278280
raise ValueError("`unroll` must be a `bool` or a non-negative `int`.")
279281

280-
args_flat = (*init.vals, *xs.vals)
282+
args_flat = [*init.vals, *xs.vals]
281283

282284
# If the body forwards an input carry to an output carry, that input is
283285
# read-only and can be moved to be a const. Doing so can lead to efficiency
@@ -381,18 +383,18 @@ def _check_carry_type(name, body_fun, in_carry, out_carry):
381383
if p else 'the input carry')
382384
if in_carry.tree != out_carry.tree:
383385
try:
384-
out_carry = out_carry.unflatten()
386+
out_carry_unflat = out_carry.unflatten()
385387
except:
386-
out_carry = None
388+
out_carry_unflat = None
387389

388-
if out_carry is None:
390+
if out_carry_unflat is None:
389391
differences = (f'the input tree structure is:\n{in_carry.tree}\n' +
390392
f'the output tree structure is:\n{out_carry.tree}\n')
391393
else:
392394
diffs = [f'{component(path)} is a {thing1} but the corresponding component '
393395
f'of the carry output is a {thing2}, so {explanation}'
394396
for path, thing1, thing2, explanation
395-
in equality_errors(in_carry, out_carry)]
397+
in equality_errors(in_carry.unflatten(), out_carry.unflatten())]
396398
if len(diffs) == 0:
397399
return # the trees may have different aux data, but structures are same
398400
elif len(diffs) == 1:
@@ -1709,7 +1711,7 @@ def _create_jaxpr(init_avals):
17091711

17101712
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
17111713
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
1712-
init_val = FlatTree.flatten(init_val)
1714+
init_val = FlatTree.flatten(init_val) # type: ignore
17131715
init_aval = init_val.map(core.get_aval)
17141716

17151717
# The body input and output avals must match exactly. However, we want to account for
@@ -1718,6 +1720,10 @@ def _create_jaxpr(init_avals):
17181720
# To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
17191721
# necessary, a second time with modified init values.
17201722
cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval)
1723+
if len(body_out_avals) != len(init_aval):
1724+
_check_carry_type('while_loop body', body_fun, init_aval, body_out_avals)
1725+
assert False, "shouldn't get here"
1726+
17211727
init_val, changed = init_val.map3(
17221728
_promote_weak_typed_input,
17231729
init_aval, body_out_avals).unzip2()
@@ -1749,7 +1755,7 @@ def _create_jaxpr(init_avals):
17491755
_, keep_cond_carry = split_list(keep_cond, [len(cond_consts)])
17501756
move_to_const = _map(operator.not_, keep_cond_carry)
17511757

1752-
init_vals = list(init_val)
1758+
init_vals = list(init_val) # type: ignore
17531759
if any(move_to_const):
17541760
cond_jaxpr = pe.close_jaxpr(cond_jaxpr_)
17551761
body_jaxpr = pe.prune_closed_jaxpr_outputs(

jax/_src/tree_util.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@
3434
traceback_util.register_exclusion(__file__)
3535

3636
T = TypeVar("T")
37-
T1 = TypeVar("T1")
38-
T2 = TypeVar("T2")
39-
T3 = TypeVar("T3")
40-
T4 = TypeVar("T4")
4137
Typ = TypeVar("Typ", bound=type[Any])
4238
H = TypeVar("H", bound=Hashable)
4339

@@ -1357,20 +1353,18 @@ class FlatTree:
13571353
the tuple-returning function would change the tree structure and `unzip`
13581354
wouldn't be able to recover it.
13591355
"""
1360-
def __init__(self, vals:Sequence[T], treedef:PyTreeDef):
1356+
def __init__(self, vals:Sequence, treedef:PyTreeDef):
13611357
assert isinstance(treedef, pytree.PyTreeDef)
13621358
self.tree = treedef
1363-
self.vals = list(vals)
1359+
self.vals = tuple(vals)
13641360

1365-
def map(self, f:Callable[[T1], T2]) -> FlatTree[T2]:
1361+
def map(self, f:Callable) -> FlatTree:
13661362
ans_vals = []
13671363
for x in self.vals:
13681364
ans_vals.append(f(x))
13691365
return FlatTree(ans_vals, self.tree)
13701366

1371-
def map2(
1372-
self:FlatTree[T1], f:Callable[[T1, T2], T3],
1373-
t2:FlatTree[T2]) -> FlatTree[T3]:
1367+
def map2(self:FlatTree, f:Callable, t2:FlatTree) -> FlatTree:
13741368

13751369
n = len(self)
13761370
assert len(t2) == n
@@ -1380,19 +1374,18 @@ def map2(
13801374
return FlatTree(ans_vals, self.tree)
13811375

13821376
def map3(
1383-
self:FlatTree[T1], f:Callable[[T1, T2, T3], T4],
1384-
t2:FlatTree[T2], t3:FlatTree[T3]) -> FlatTree[T4]:
1377+
self:FlatTree, f:Callable, t2:FlatTree, t3:FlatTree) -> FlatTree:
13851378
n = len(self)
13861379
assert len(t2) == n and len(t3) == n
13871380
ans_vals = []
13881381
for x1, x2, x3 in zip(self.vals, t2.vals, t3.vals):
13891382
ans_vals.append(f(x1, x2, x3))
13901383
return FlatTree(ans_vals, self.tree)
13911384

1392-
def zip(self, t2:FlatTree[T2]) -> FlatTree[tuple[T1, T2]]:
1385+
def zip(self, t2:FlatTree) -> FlatTree:
13931386
assert False
13941387

1395-
def unzip2(self:FlatTree[tuple[T1, T2]]) -> tuple[FlatTree[T1], FlatTree[T2]]:
1388+
def unzip2(self:FlatTree) -> tuple[FlatTree, FlatTree]:
13961389
ys = []
13971390
zs = []
13981391
for y, z in self.vals:
@@ -1425,7 +1418,7 @@ def pack(tree):
14251418
else:
14261419
assert False
14271420

1428-
def unpack(self:FlatTree[tuple]) -> tuple[FlatTree]:
1421+
def unpack(self:FlatTree) -> tuple[FlatTree, ...]:
14291422
# TODO: this is O(N) not O(1) (with N as the number of leaves). If it
14301423
# becomes a problem we can fix it with a fancier data tree.
14311424
trees = treedef_children(self.tree)
@@ -1444,11 +1437,19 @@ def flatten(tree: PyTree) -> FlatTree:
14441437
def unflatten(self) -> PyTree:
14451438
return tree_unflatten(self.tree, self.vals)
14461439

1447-
def update_from_list(self, new_vals:list[T1]) -> FlatTree[T1]:
1440+
def update_from_list(self, new_vals:list) -> FlatTree:
14481441
return FlatTree(new_vals, self.tree)
14491442

14501443
def __len__(self):
14511444
return self.tree.num_leaves
14521445

14531446
def __iter__(self):
14541447
return self.vals.__iter__()
1448+
1449+
def __eq__(self, other):
1450+
return (isinstance(other, FlatTree)
1451+
and self.vals == other.vals
1452+
and self.tree == other.tree)
1453+
1454+
def __hash__(self):
1455+
return hash((self.vals, self.tree))

0 commit comments

Comments
 (0)