Skip to content

Commit c834d5d

Browse files
committed
Adapted the `get_pragma_parameters` utilities: instead of relying on a pure regular expression, the string is now scanned and split into chunks so that (nested) parentheses are matched correctly
1 parent a60479e commit c834d5d

File tree

2 files changed

+65
-27
lines changed

2 files changed

+65
-27
lines changed

loki/ir/pragma_utils.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,62 @@ def is_loki_pragma(pragma, starts_with=None):
5050
return True
5151

5252

53-
_get_pragma_parameters_re = re.compile(r'(?P<command>[\w-]+)\s*(?:\((?P<arg>.+?)\))?')
54-
"""
55-
Regular expression pattern to match pragma parameters.
56-
57-
E.g., match ``!$loki something key1(val1) key2(val2)``.
58-
Problematic for e.g., ``!$loki something key1((val1 + 1)/2)``,
59-
use instead: `_get_pragma_dim_parameter_re`.
60-
"""
61-
62-
def get_pragma_parameters(pragma, starts_with=None, only_loki_pragmas=True,
63-
pattern=_get_pragma_parameters_re):
53+
class PragmaParameters:
    """
    Utility class to parse strings for parameters in the form ``<command>[(<arg>)]``.

    Parentheses are located by position and paired up explicitly instead of
    being captured by one regular expression, so an argument may itself
    contain (nested) parentheses, e.g. ``key1((val1 + 1)/2)``.
    """

    # Patterns used to locate every opening/closing parenthesis in the string
    _pattern_opening_parenthesis = re.compile(r'\(')
    _pattern_closing_parenthesis = re.compile(r'\)')
    # Single- or double-quoted strings (non-greedy); removed before parsing
    _pattern_quoted_string = re.compile(r'(?:\'.*?\')|(?:".*?")')

    @classmethod
    def find(cls, string):
        """
        Find parameters in the form ``<command>[(<arg>)]`` in :data:`string`.

        Quoted strings are removed from :data:`string` before parsing, so
        they do not show up in the returned arguments.

        Parameters
        ----------
        string : str
            The text to parse, e.g. ``'key1(val1) flag key2((a + b)/2)'``.

        Returns
        -------
        dict :
            Mapping ``{<command>: <arg> or None}``; the value is ``None`` for
            a command without parentheses and a list when the same command
            occurs more than once. A parenthesised group that is not preceded
            by a command is stored under the empty string ``''``.

        Raises
        ------
        AssertionError
            If the number of opening and closing parentheses differs
            (only when assertions are enabled).
        """
        # Drop quoted strings up front so that parentheses inside them
        # cannot upset the pairing below
        string = cls._pattern_quoted_string.sub('', string)
        p_open = [match.start() for match in cls._pattern_opening_parenthesis.finditer(string)]
        p_close = [match.start() for match in cls._pattern_closing_parenthesis.finditer(string)]
        assert len(p_open) == len(p_close)

        def _match_spans(open_, close_):
            # We match pairs of parentheses starting at the end by pushing and popping from a stack.
            # Whenever the stack runs out, we have fully resolved a set of (nested) parenthesis and
            # record the corresponding span
            if not close_:
                return []
            spans = []
            stack = [close_.pop()]
            while open_:
                if not close_ or open_[-1] > close_[-1]:
                    assert stack
                    start = open_.pop()
                    end = stack.pop()
                    if not stack:
                        spans.append((start, end))
                else:
                    stack.append(close_.pop())
            assert not (stack or open_ or close_)
            return spans

        # The spans are discovered back to front; restore textual order
        spans = _match_spans(p_open, p_close)[::-1]

        parameters = defaultdict(list)

        def _add_bare_commands(tokens):
            # Commands without a parenthesised argument map to ``None``
            for token in tokens:
                parameters[token].append(None)

        previous_end = 0
        for start, end in spans:
            # Everything between the previous argument and this one is a
            # whitespace-separated list of commands; only the last of them
            # owns the parenthesised argument
            tokens = string[previous_end:start].split()
            _add_bare_commands(tokens[:-1])
            command = tokens[-1] if tokens else ''
            parameters[command].append(string[start+1:end])
            previous_end = end + 1

        # Commands following the last parenthesised argument (or the entire
        # string if there is none) carry no argument
        _add_bare_commands(string[previous_end:].split())

        return {k: v if len(v) > 1 else v[0] for k, v in parameters.items()}
106+
107+
108+
def get_pragma_parameters(pragma, starts_with=None, only_loki_pragmas=True):
    """
    Parse the pragma content for parameters in the form ``<command>[(<arg>)]`` and
    return them as a map ``{<command>: <arg> or None}``.

    Parameters
    ----------
    pragma : (tuple of) pragma object(s)
        one or multiple pragmas to parse; each one is accessed via its
        ``keyword`` and ``content`` attributes.
    starts_with : str, optional
        the keyword the pragma content should start with. Pragmas whose content
        does not start with it (compared case-insensitively) are skipped, and the
        keyword itself is stripped from the content before parsing.
    only_loki_pragmas : bool, optional
        restrict parameter extraction to ``loki`` pragmas only.

    Returns
    -------
    dict :
        Mapping of parameters ``{<command>: <arg> or <None>}`` with the values being a list
        when multiple entries have the same key
    """
    parameters = defaultdict(list)
    for p in as_tuple(pragma):
        if only_loki_pragmas and p.keyword.lower() != 'loki':
            continue
        content = p.content or ''
        if starts_with is not None:
            if not content.lower().startswith(starts_with.lower()):
                continue
            content = content[len(starts_with):]
        for key, value in PragmaParameters.find(content).items():
            # A key repeated within one pragma already comes back as a list;
            # merge it in rather than nesting it, so that entries collected
            # from several pragmas always end up in one flat list
            if isinstance(value, list):
                parameters[key].extend(value)
            else:
                parameters[key].append(value)
    return {k: v if len(v) > 1 else v[0] for k, v in parameters.items()}
104150

105151

106-
_get_pragma_dim_parameter_re = re.compile(r'(?P<command>[\w-]+)\s*(?:\((?P<arg>.*)\))?')
107-
"""
108-
Regular expression pattern to match pragma dimension parameter.
109-
110-
E.g., match ``!$loki something key1((val1 + 1)/2)``.
111-
Problematic for e.g., ``!$loki something key1(val1) key2(val2)``,
112-
use instead: `_get_pragma_parameters_re`.
113-
"""
114-
115152
def process_dimension_pragmas(ir, scope=None):
116153
"""
117154
Process any ``!$loki dimension`` pragmas to override deferred dimensions
@@ -130,8 +167,7 @@ def process_dimension_pragmas(ir, scope=None):
130167
if is_loki_pragma(decl.pragma, starts_with='dimension'):
131168
for v in decl.symbols:
132169
# Found dimension override for variable
133-
dims = get_pragma_parameters(decl.pragma,
134-
pattern=_get_pragma_dim_parameter_re)['dimension']
170+
dims = get_pragma_parameters(decl.pragma)['dimension']
135171
dims = [d.strip() for d in dims.split(',')]
136172
# parse each dimension
137173
shape = tuple(parse_expr(d, scope=scope) for d in dims)

loki/ir/tests/test_pragma_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def test_is_loki_pragma(keyword, content, starts_with, ref):
5050
('dataflow group(1) group(2)', 'dataflow', {'group': ['1', '2']}),
5151
('foo bar(^£!$%*[]:@+-_=~#/?.,<>;) baz foobar(abc_123")', 'foo',
5252
{'bar':'^£!$%*[]:@+-_=~#/?.,<>;', 'baz': None, 'foobar': 'abc_123"'}),
53-
('target map(a) map(to: b) map(from: c)', None, {'target': None, 'map': ['a', 'to: b', 'from: c']})
53+
('target map(a) map(to: b) map(from: c)', None, {'target': None, 'map': ['a', 'to: b', 'from: c']}),
54+
('arg1(val1) arg2(val2/val3) arg3((val1 + val2)/(val3))', None, {'arg1': 'val1',
55+
'arg2': 'val2/val3', 'arg3': '(val1 + val2)/(val3)'})
5456
])
5557
def test_get_pragma_parameters(content, starts_with, ref):
5658
"""

0 commit comments

Comments
 (0)