Skip to content

Commit 285f266

Browse files
committed
add tests for comparison folding
1 parent 713fde5 commit 285f266

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

Lib/test/test_ast/test_ast.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3240,6 +3240,53 @@ def test_folding_tuple(self):
32403240

32413241
self.assert_ast(code, non_optimized_target, optimized_target)
32423242

3243+
def test_folding_compare(self):
3244+
true = self.wrap_expr(ast.Constant(value=True))
3245+
false = self.wrap_expr(ast.Constant(value=False))
3246+
3247+
folded_cases = (
3248+
("3 > 2 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=2), ast.Constant(value=1)]), true),
3249+
("3 > 4 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=4), ast.Constant(value=1)]), false),
3250+
("3 >= 3 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=3), ast.Constant(value=1)]), true),
3251+
("3 >= 4 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=4), ast.Constant(value=1)]), false),
3252+
("1 < 2 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=2), ast.Constant(value=3)]), true),
3253+
("1 < 0 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=0), ast.Constant(value=3)]), false),
3254+
("1 <= 2 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=2), ast.Constant(value=3)]), true),
3255+
("1 <= 0 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=0), ast.Constant(value=3)]), false),
3256+
("1 == 1.0 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=1.0), ast.Constant(value=True)]), true),
3257+
("1 == 2 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=2), ast.Constant(value=True)]), false),
3258+
("1 != 2 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=2), ast.Constant(value=3)]), true),
3259+
("1 != 1 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=1), ast.Constant(value=3)]), false),
3260+
("1 in [1, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), true),
3261+
("1 in [2, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), false),
3262+
("1 not in [1, 2]", (ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), false),
3263+
("1 not in [2, 2]", (ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), true),
3264+
)
3265+
3266+
for code, original, folded in folded_cases:
3267+
left, ops, comparators = original
3268+
unfolded = self.wrap_expr(ast.Compare(left=left, ops=ops, comparators=comparators))
3269+
self.assert_ast(code=code, non_optimized_target=unfolded, optimized_target=folded)
3270+
3271+
# these should stay as they were
3272+
unfolded_cases = (
3273+
("3 > 2 > []", ast.Compare(left=ast.Constant(3), ops=[ast.Gt(), ast.Gt()], comparators=[ast.Constant(2), ast.List()])),
3274+
("1 > [] > 0", ast.Compare(left=ast.Constant(1), ops=[ast.Gt(), ast.Gt()], comparators=[ast.List(), ast.Constant(0)])),
3275+
("1 >= [] >= 0", ast.Compare(left=ast.Constant(1), ops=[ast.GtE(), ast.GtE()], comparators=[ast.List(), ast.Constant(0)])),
3276+
("1 < [] < 0", ast.Compare(left=ast.Constant(1), ops=[ast.Lt(), ast.Lt()], comparators=[ast.List(), ast.Constant(0)])),
3277+
("1 <= [] <= 0", ast.Compare(left=ast.Constant(1), ops=[ast.LtE(), ast.LtE()], comparators=[ast.List(), ast.Constant(0)])),
3278+
("1 == [] == 0", ast.Compare(left=ast.Constant(1), ops=[ast.Eq(), ast.Eq()], comparators=[ast.List(), ast.Constant(0)])),
3279+
("1 != [] != 0", ast.Compare(left=ast.Constant(1), ops=[ast.NotEq(), ast.NotEq()], comparators=[ast.List(), ast.Constant(0)])),
3280+
("1 is 1", ast.Compare(left=ast.Constant(1), ops=[ast.Is()], comparators=[ast.Constant(1)])),
3281+
("1 is not 1", ast.Compare(left=ast.Constant(1), ops=[ast.IsNot()], comparators=[ast.Constant(1)])),
3282+
# invalid also should stay as they were
3283+
("1 in 1", ast.Compare(left=ast.Constant(1), ops=[ast.In()], comparators=[ast.Constant(1)])),
3284+
("1 not in 1", ast.Compare(left=ast.Constant(1), ops=[ast.NotIn()], comparators=[ast.Constant(1)])),
3285+
)
3286+
3287+
for code, expected in unfolded_cases:
3288+
self.assertTrue(ast.compare(ast.parse(code), self.wrap_expr(expected)))
3289+
32433290
def test_folding_comparator_list_set_subst(self):
32443291
"""Test substitution of list/set with tuple/frozenset in expressions like "1 in [1]" or "1 in {1}" """
32453292

0 commit comments

Comments
 (0)