1313import numpy as np
1414
1515import pymbolic .primitives as pmbl
16+ import pymbolic .mapper as pmbl_mapper
1617
1718from loki import (
1819 Sourcefile , Subroutine , Module , Scope , BasicType ,
@@ -2109,7 +2110,7 @@ def test_expression_parser_evaluate(case):
21092110 context = {'a' : 6 }
21102111 test_str = '1 + 1 + a + some_func(a, 10)'
21112112 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2112- with pytest .raises (Exception ):
2113+ with pytest .raises (pmbl_mapper . evaluator . UnknownVariableError ):
21132114 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , strict = True , context = context )
21142115 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , strict = False , context = context )
21152116 assert str (parsed ).lower ().replace (' ' , '' ) == '8+some_func(6,10)'
@@ -2134,7 +2135,23 @@ def some_func(a, b, c=None):
21342135 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
21352136 assert str (parsed ).lower ().replace (' ' , '' ) == '(13+c+1)/(c+1)'
21362137
2138+ class BarBarBar :
2139+ val_barbarbar = 5
2140+
2141+ class BarBar :
2142+ barbarbar = BarBarBar ()
2143+ val_barbar = - 3
2144+ def barbar_func (self , a ):
2145+ return a - 1
2146+
2147+ class Bar :
2148+ barbar = BarBar ()
2149+ val_bar = 5
2150+ def bar_func (self , a ):
2151+ return a ** 2
2152+
21372153 class Foo :
2154+ bar = Bar () # pylint: disable=disallowed-name
21382155 val3 = 1
21392156 arr = [[1 , 2 ], [3 , 4 ]]
21402157 def __init__ (self , _val1 , _val2 ):
@@ -2150,18 +2167,33 @@ def static_func(a):
21502167 test_str = 'foo%val1 + foo%val2 + foo%val3'
21512168 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ))
21522169 assert str (parsed ).lower ().replace (' ' , '' ) == 'foo%val1+foo%val2+foo%val3'
2153- with pytest .raises (Exception ):
2170+ with pytest .raises (pmbl_mapper . evaluator . UnknownVariableError ):
21542171 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , strict = True )
21552172 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
21562173 assert parsed == 6
21572174 test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func_2(3)'
21582175 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
21592176 assert str (parsed ).lower ().replace (' ' , '' ) == '5+foo%static_func_2(3)'
2160- with pytest .raises (Exception ):
2177+ with pytest .raises (pmbl_mapper . evaluator . UnknownVariableError ):
21612178 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , strict = True )
21622179 test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func(3) + foo%arr(1, 2)'
21632180 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context , strict = True )
21642181 assert parsed == 13
21652182 test_str = 'foo%val1 + foo%some_func(1, b=2) + foo%static_func(a=3) + foo%arr(1, 2)'
21662183 parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context , strict = True )
21672184 assert parsed == 13
2185+ test_str = 'foo%bar%val_bar + 1'
2186+ parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2187+ assert parsed == 6
2188+ test_str = 'foo%bar%bar_func(2) + 1'
2189+ parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2190+ assert parsed == 5
2191+ test_str = 'foo%bar%barbar%val_barbar + 1'
2192+ parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2193+ assert parsed == - 2
2194+ test_str = 'foo%bar%barbar%barbar_func(0) + 1'
2195+ parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2196+ assert parsed == 0
2197+ test_str = 'foo%bar%barbar%barbarbar%val_barbarbar + 1'
2198+ parsed = parse_expr (convert_to_case (f'{ test_str } ' , mode = case ), evaluate = True , context = context )
2199+ assert parsed == 6
0 commit comments