@@ -3240,6 +3240,53 @@ def test_folding_tuple(self):
32403240
32413241 self .assert_ast (code , non_optimized_target , optimized_target )
32423242
3243+ def test_folding_compare (self ):
3244+ true = self .wrap_expr (ast .Constant (value = True ))
3245+ false = self .wrap_expr (ast .Constant (value = False ))
3246+
3247+ folded_cases = (
3248+ ("3 > 2 > 1" , (ast .Constant (3 ), [ast .Gt (), ast .Gt ()], [ast .Constant (value = 2 ), ast .Constant (value = 1 )]), true ),
3249+ ("3 > 4 > 1" , (ast .Constant (3 ), [ast .Gt (), ast .Gt ()], [ast .Constant (value = 4 ), ast .Constant (value = 1 )]), false ),
3250+ ("3 >= 3 >= 1" , (ast .Constant (3 ), [ast .GtE (), ast .GtE ()], [ast .Constant (value = 3 ), ast .Constant (value = 1 )]), true ),
3251+ ("3 >= 4 >= 1" , (ast .Constant (3 ), [ast .GtE (), ast .GtE ()], [ast .Constant (value = 4 ), ast .Constant (value = 1 )]), false ),
3252+ ("1 < 2 < 3" , (ast .Constant (1 ), [ast .Lt (), ast .Lt ()], [ast .Constant (value = 2 ), ast .Constant (value = 3 )]), true ),
3253+ ("1 < 0 < 3" , (ast .Constant (1 ), [ast .Lt (), ast .Lt ()], [ast .Constant (value = 0 ), ast .Constant (value = 3 )]), false ),
3254+ ("1 <= 2 <= 3" , (ast .Constant (1 ), [ast .LtE (), ast .LtE ()], [ast .Constant (value = 2 ), ast .Constant (value = 3 )]), true ),
3255+ ("1 <= 0 <= 3" , (ast .Constant (1 ), [ast .LtE (), ast .LtE ()], [ast .Constant (value = 0 ), ast .Constant (value = 3 )]), false ),
3256+ ("1 == 1.0 == True" , (ast .Constant (1 ), [ast .Eq (), ast .Eq ()], [ast .Constant (value = 1.0 ), ast .Constant (value = True )]), true ),
3257+ ("1 == 2 == True" , (ast .Constant (1 ), [ast .Eq (), ast .Eq ()], [ast .Constant (value = 2 ), ast .Constant (value = True )]), false ),
3258+ ("1 != 2 != 3" , (ast .Constant (1 ), [ast .NotEq (), ast .NotEq ()], [ast .Constant (value = 2 ), ast .Constant (value = 3 )]), true ),
3259+ ("1 != 1 != 3" , (ast .Constant (1 ), [ast .NotEq (), ast .NotEq ()], [ast .Constant (value = 1 ), ast .Constant (value = 3 )]), false ),
3260+ ("1 in [1, 2]" , (ast .Constant (1 ), [ast .In ()], [ast .List (elts = [ast .Constant (1 ), ast .Constant (2 )])]), true ),
3261+ ("1 in [2, 2]" , (ast .Constant (1 ), [ast .In ()], [ast .List (elts = [ast .Constant (2 ), ast .Constant (2 )])]), false ),
3262+ ("1 not in [1, 2]" , (ast .Constant (1 ), [ast .NotIn ()], [ast .List (elts = [ast .Constant (1 ), ast .Constant (2 )])]), false ),
3263+ ("1 not in [2, 2]" , (ast .Constant (1 ), [ast .NotIn ()], [ast .List (elts = [ast .Constant (2 ), ast .Constant (2 )])]), true ),
3264+ )
3265+
3266+ for code , original , folded in folded_cases :
3267+ left , ops , comparators = original
3268+ unfolded = self .wrap_expr (ast .Compare (left = left , ops = ops , comparators = comparators ))
3269+ self .assert_ast (code = code , non_optimized_target = unfolded , optimized_target = folded )
3270+
3271+ # these should stay as they were
3272+ unfolded_cases = (
3273+ ("3 > 2 > []" , ast .Compare (left = ast .Constant (3 ), ops = [ast .Gt (), ast .Gt ()], comparators = [ast .Constant (2 ), ast .List ()])),
3274+ ("1 > [] > 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .Gt (), ast .Gt ()], comparators = [ast .List (), ast .Constant (0 )])),
3275+ ("1 >= [] >= 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .GtE (), ast .GtE ()], comparators = [ast .List (), ast .Constant (0 )])),
3276+ ("1 < [] < 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .Lt (), ast .Lt ()], comparators = [ast .List (), ast .Constant (0 )])),
3277+ ("1 <= [] <= 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .LtE (), ast .LtE ()], comparators = [ast .List (), ast .Constant (0 )])),
3278+ ("1 == [] == 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .Eq (), ast .Eq ()], comparators = [ast .List (), ast .Constant (0 )])),
3279+ ("1 != [] != 0" , ast .Compare (left = ast .Constant (1 ), ops = [ast .NotEq (), ast .NotEq ()], comparators = [ast .List (), ast .Constant (0 )])),
3280+ ("1 is 1" , ast .Compare (left = ast .Constant (1 ), ops = [ast .Is ()], comparators = [ast .Constant (1 )])),
3281+ ("1 is not 1" , ast .Compare (left = ast .Constant (1 ), ops = [ast .IsNot ()], comparators = [ast .Constant (1 )])),
3282+ # invalid also should stay as they were
3283+ ("1 in 1" , ast .Compare (left = ast .Constant (1 ), ops = [ast .In ()], comparators = [ast .Constant (1 )])),
3284+ ("1 not in 1" , ast .Compare (left = ast .Constant (1 ), ops = [ast .NotIn ()], comparators = [ast .Constant (1 )])),
3285+ )
3286+
3287+ for code , expected in unfolded_cases :
3288+ self .assertTrue (ast .compare (ast .parse (code ), self .wrap_expr (expected )))
3289+
32433290 def test_folding_comparator_list_set_subst (self ):
32443291 """Test substitution of list/set with tuple/frozenset in expressions like "1 in [1]" or "1 in {1}" """
32453292
0 commit comments