Skip to content

Commit de79635

Browse files
committed
More specific error message when passing a different set of tunable parameters than what is in the cache
1 parent e71b8f6 commit de79635

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

kernel_tuner/util.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,24 @@
2323

2424
# number of special values to insert when a configuration cannot be measured
2525

26+
2627
class ErrorConfig(str):
27-
def __str__(self): return self.__class__.__name__
28-
def __repr__(self): return self.__class__.__name__
28+
29+
def __str__(self):
30+
return self.__class__.__name__
31+
32+
def __repr__(self):
33+
return self.__class__.__name__
34+
2935

3036
class InvalidConfig(ErrorConfig):
3137
pass
3238

39+
3340
class CompilationFailedConfig(ErrorConfig):
3441
pass
3542

43+
3644
class RuntimeFailedConfig(ErrorConfig):
3745
pass
3846

@@ -46,6 +54,7 @@ def __init__(self):
4654
class SkippableFailure(Exception):
4755
"""Exception used to raise when compiling or launching a kernel fails for a reason that can be expected"""
4856

57+
4958
class StopCriterionReached(Exception):
5059
"""Exception thrown when a stop criterion has been reached"""
5160

@@ -122,7 +131,7 @@ def check_stop_criterion(to):
122131
""" checks if max_fevals is reached or time limit is exceeded """
123132
if "max_fevals" in to and len(to.unique_results) >= to.max_fevals:
124133
raise StopCriterionReached("max_fevals reached")
125-
if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time*1e-3)) > to.time_limit):
134+
if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit):
126135
raise StopCriterionReached("time limit exceeded")
127136

128137

@@ -663,6 +672,7 @@ def normalize_verify_function(v):
663672
664673
Undefined behaviour if the passed function does not match the required signatures.
665674
"""
675+
666676
# python 3.3+
667677
def has_kw_argument(func, name):
668678
sig = signature(func)
@@ -681,12 +691,14 @@ def parse_restrictions(restrictions: list, tune_params: dict):
681691

682692
# rewrite the restrictions so variables are singled out
683693
regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)"
694+
684695
def replace_params(match_object):
685696
key = match_object.group(1)
686697
if key in tune_params:
687698
return 'params["' + key + '"]'
688699
else:
689700
return key
701+
690702
parsed = ") and (".join([re.sub(regex_match_variable, replace_params, res) for res in restrictions])
691703

692704
# tidy up the code by removing the last suffix and unnecessary spaces
@@ -709,6 +721,7 @@ def compile_restrictions(restrictions: list, tune_params: dict):
709721

710722

711723
class NpEncoder(json.JSONEncoder):
724+
712725
def default(self, obj):
713726
if isinstance(obj, np.integer):
714727
return int(obj)
@@ -794,7 +807,11 @@ def process_cache(cache, kernel_options, tuning_options, runner):
794807
elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]):
795808
raise ValueError("Cannot load cache which contains results for different problem_size")
796809
if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()):
797-
raise ValueError("Cannot load cache which contains results obtained with different tunable parameters")
810+
if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]):
811+
raise ValueError(f"All tunable parameters are present, but the order is wrong. \
812+
Cache has order: {cached_data['tune_params_keys']}, tuning_options has: {list(tuning_options.tune_params.keys())}")
813+
raise ValueError(f"Cannot load cache which contains results obtained with different tunable parameters. \
814+
Cache has: {cached_data['tune_params_keys']}, tuning_options has: {list(tuning_options.tune_params.keys())}")
798815

799816
tuning_options.cachefile = cache
800817
tuning_options.cache = cached_data["cache"]
@@ -817,9 +834,11 @@ def read_cache(cache, open_cache=True):
817834
with open(cache, "w") as cachefile:
818835
cachefile.write(filestr[:-3] + ",")
819836

820-
error_configs = {"InvalidConfig": InvalidConfig(),
821-
"CompilationFailedConfig": CompilationFailedConfig(),
822-
"RuntimeFailedConfig": RuntimeFailedConfig()}
837+
error_configs = {
838+
"InvalidConfig": InvalidConfig(),
839+
"CompilationFailedConfig": CompilationFailedConfig(),
840+
"RuntimeFailedConfig": RuntimeFailedConfig()
841+
}
823842

824843
# replace strings with ErrorConfig instances
825844
cache_data = json.loads(filestr)
@@ -864,7 +883,7 @@ def JSONconverter(obj):
864883

865884
# Convert ErrorConfig objects to string, wanted to do this inside the JSONconverter but couldn't get it to work
866885
output_params = params.copy()
867-
for k,v in output_params.items():
886+
for k, v in output_params.items():
868887
if isinstance(v, ErrorConfig):
869888
output_params[k] = str(v)
870889

0 commit comments

Comments
 (0)