2323
2424# number of special values to insert when a configuration cannot be measured
2525
26+
2627class ErrorConfig (str ):
27- def __str__ (self ): return self .__class__ .__name__
28- def __repr__ (self ): return self .__class__ .__name__
28+
29+ def __str__ (self ):
30+ return self .__class__ .__name__
31+
32+ def __repr__ (self ):
33+ return self .__class__ .__name__
34+
2935
3036class InvalidConfig (ErrorConfig ):
3137 pass
3238
39+
3340class CompilationFailedConfig (ErrorConfig ):
3441 pass
3542
43+
3644class RuntimeFailedConfig (ErrorConfig ):
3745 pass
3846
@@ -46,6 +54,7 @@ def __init__(self):
4654class SkippableFailure (Exception ):
4755 """Exception used to raise when compiling or launching a kernel fails for a reason that can be expected"""
4856
57+
4958class StopCriterionReached (Exception ):
5059 """Exception thrown when a stop criterion has been reached"""
5160
@@ -122,7 +131,7 @@ def check_stop_criterion(to):
122131 """ checks if max_fevals is reached or time limit is exceeded """
123132 if "max_fevals" in to and len (to .unique_results ) >= to .max_fevals :
124133 raise StopCriterionReached ("max_fevals reached" )
125- if "time_limit" in to and (((time .perf_counter () - to .start_time ) + (to .simulated_time * 1e-3 )) > to .time_limit ):
134+ if "time_limit" in to and (((time .perf_counter () - to .start_time ) + (to .simulated_time * 1e-3 )) > to .time_limit ):
126135 raise StopCriterionReached ("time limit exceeded" )
127136
128137
@@ -663,6 +672,7 @@ def normalize_verify_function(v):
663672
664673 Undefined behaviour if the passed function does not match the required signatures.
665674 """
675+
666676 # python 3.3+
667677 def has_kw_argument (func , name ):
668678 sig = signature (func )
@@ -681,12 +691,14 @@ def parse_restrictions(restrictions: list, tune_params: dict):
681691
682692 # rewrite the restrictions so variables are singled out
683693 regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)"
694+
684695 def replace_params (match_object ):
685696 key = match_object .group (1 )
686697 if key in tune_params :
687698 return 'params["' + key + '"]'
688699 else :
689700 return key
701+
690702 parsed = ") and (" .join ([re .sub (regex_match_variable , replace_params , res ) for res in restrictions ])
691703
692704 # tidy up the code by removing the last suffix and unnecessary spaces
@@ -709,6 +721,7 @@ def compile_restrictions(restrictions: list, tune_params: dict):
709721
710722
711723class NpEncoder (json .JSONEncoder ):
724+
712725 def default (self , obj ):
713726 if isinstance (obj , np .integer ):
714727 return int (obj )
@@ -794,7 +807,11 @@ def process_cache(cache, kernel_options, tuning_options, runner):
794807 elif not all ([i == j for i , j in zip (cached_data ["problem_size" ], kernel_options .problem_size )]):
795808 raise ValueError ("Cannot load cache which contains results for different problem_size" )
796809 if cached_data ["tune_params_keys" ] != list (tuning_options .tune_params .keys ()):
797- raise ValueError ("Cannot load cache which contains results obtained with different tunable parameters" )
810+ if all (key in tuning_options .tune_params for key in cached_data ["tune_params_keys" ]):
811+ raise ValueError (f"All tunable parameters are present, but the order is wrong. \
812+ Cache has order: { cached_data ['tune_params_keys' ]} , tuning_options has: { list (tuning_options .tune_params .keys ())} " )
813+ raise ValueError (f"Cannot load cache which contains results obtained with different tunable parameters. \
814+ Cache has: { cached_data ['tune_params_keys' ]} , tuning_options has: { list (tuning_options .tune_params .keys ())} " )
798815
799816 tuning_options .cachefile = cache
800817 tuning_options .cache = cached_data ["cache" ]
@@ -817,9 +834,11 @@ def read_cache(cache, open_cache=True):
817834 with open (cache , "w" ) as cachefile :
818835 cachefile .write (filestr [:- 3 ] + "," )
819836
820- error_configs = {"InvalidConfig" : InvalidConfig (),
821- "CompilationFailedConfig" : CompilationFailedConfig (),
822- "RuntimeFailedConfig" : RuntimeFailedConfig ()}
837+ error_configs = {
838+ "InvalidConfig" : InvalidConfig (),
839+ "CompilationFailedConfig" : CompilationFailedConfig (),
840+ "RuntimeFailedConfig" : RuntimeFailedConfig ()
841+ }
823842
824843 # replace strings with ErrorConfig instances
825844 cache_data = json .loads (filestr )
@@ -864,7 +883,7 @@ def JSONconverter(obj):
864883
865884 # Convert ErrorConfig objects to string, wanted to do this inside the JSONconverter but couldn't get it to work
866885 output_params = params .copy ()
867- for k ,v in output_params .items ():
886+ for k , v in output_params .items ():
868887 if isinstance (v , ErrorConfig ):
869888 output_params [k ] = str (v )
870889
0 commit comments