2323import argparse
2424import re
2525import sys
26- import uuid
26+ import os
2727
2828# C type: const int
2929ERROR_CLASSES = [
@@ -1042,6 +1042,30 @@ def need_bigcount(self):
10421042 return any ('COUNT' in param .type_ for param in self .params )
10431043
10441044
1045+ class TemplateParseError (Exception ):
1046+ """Error raised during parsing."""
1047+ pass
1048+
1049+
1050+ def validate_body (body ):
1051+ """Validate the body of a template."""
1052+ # Just do a simple bracket balance test determine the bounds of the
1053+ # function body. All lines after the function body should be blank. There
1054+ # are cases where this will break, such as if someone puts code all on one
1055+ # line.
1056+ bracket_balance = 0
1057+ line_count = 0
1058+ for line in body :
1059+ line = line .strip ()
1060+ if bracket_balance == 0 and line_count > 0 and line :
1061+ raise TemplateParserError ('Extra code found in template; only one function body is allowed' )
1062+
1063+ update = line .count ('{' ) - line .count ('}' )
1064+ bracket_balance += update
1065+ if bracket_balance != 0 :
1066+ line_count += 1
1067+
1068+
10451069class SourceTemplate :
10461070 """Source template for a single API function."""
10471071
@@ -1051,8 +1075,10 @@ def __init__(self, prototype, header, body):
10511075 self .body = body
10521076
10531077 @staticmethod
1054- def load (fname ):
1078+ def load (fname , prefix = None ):
10551079 """Load a template file and return the SourceTemplate."""
1080+ if prefix is not None :
1081+ fname = os .path .join (prefix , fname )
10561082 with open (fname ) as fp :
10571083 header = []
10581084 prototype = []
@@ -1061,11 +1087,12 @@ def load(fname):
10611087 for line in fp :
10621088 line = line .rstrip ()
10631089 if prototype and line .startswith ('PROTOTYPE' ):
1064- raise RuntimeError ('more than one prototype found in template file' )
1090+ raise TemplateParseError ('more than one prototype found in template file' )
10651091 elif ((prototype and not any (')' in s for s in prototype ))
10661092 or line .startswith ('PROTOTYPE' )):
10671093 prototype .append (line )
10681094 elif prototype :
1095+ # Validate bracket balance
10691096 body .append (line )
10701097 else :
10711098 header .append (line )
@@ -1082,6 +1109,8 @@ def load(fname):
10821109 params = [param .strip () for param in prototype [i + 1 :j ].split (',' ) if param .strip ()]
10831110 params = [Parameter (param ) for param in params ]
10841111 prototype = Prototype (name , return_type , params )
1112+ # Ensure the body contains only one function
1113+ validate_body (body )
10851114 return SourceTemplate (prototype , header , body )
10861115
10871116 def print_header (self , file = sys .stdout ):
@@ -1148,7 +1177,7 @@ def standard_abi(base_name, template):
11481177 print (f'#include "{ ABI_INTERNAL_HEADER } "' )
11491178
11501179 # Static internal function (add a random component to avoid conflicts)
1151- internal_name = f'ompi_ { template .prototype .name } _ { uuid . uuid4 (). hex [: 10 ] } '
1180+ internal_name = f'ompi_abi_ { template .prototype .name } '
11521181 internal_sig = template .prototype .signature ('ompi' , internal_name ,
11531182 count_type = 'MPI_Count' )
11541183 print ('static inline' , internal_sig )
@@ -1190,7 +1219,7 @@ def generate_function(prototype, fn_name, internal_fn, count_type='int'):
11901219
11911220def gen_header (args ):
11921221 """Generate an ABI header and conversion code."""
1193- prototypes = [SourceTemplate .load (file_ ).prototype for file_ in args .file ]
1222+ prototypes = [SourceTemplate .load (file_ , args . srcdir ).prototype for file_ in args .file ]
11941223
11951224 builder = ABIHeaderBuilder (prototypes , external = args .external )
11961225 builder .dump_header ()
@@ -1219,6 +1248,7 @@ def main():
12191248 parser_header = subparsers .add_parser ('header' )
12201249 parser_header .add_argument ('file' , nargs = '+' , help = 'list of template source files' )
12211250 parser_header .add_argument ('--external' , action = 'store_true' , help = 'generate external mpi.h header file' )
1251+ parser_header .add_argument ('--srcdir' , help = 'source directory' )
12221252 parser_header .set_defaults (func = gen_header )
12231253
12241254 parser_gen = subparsers .add_parser ('source' )
0 commit comments