11"""
22Helper functions for testing.
33"""
4-
54import inspect
65import os
6+ import string
77
88from matplotlib .testing .compare import compare_images
9-
109from ..exceptions import GMTImageComparisonFailure
1110
1211
13- def check_figures_equal (* , tol = 0.0 , result_dir = "result_images" ):
12+ def check_figures_equal (* , extensions = ( "png" ,), tol = 0.0 , result_dir = "result_images" ):
1413 """
1514 Decorator for test cases that generate and compare two figures.
1615
17- The decorated function must take two arguments, *fig_ref* and *fig_test*,
18- and draw the reference and test images on them. After the function
19- returns, the figures are saved and compared.
16+ The decorated function must return two arguments, *fig_ref* and *fig_test*,
17+ these two figures will then be saved and compared against each other.
2018
2119 This decorator is practically identical to matplotlib's check_figures_equal
2220 function, but adapted for PyGMT figures. See also the original code at
@@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
2523
2624 Parameters
2725 ----------
26+ extensions : list
27+ The extensions to test. Default is ["png"].
2828 tol : float
2929 The RMS threshold above which the test is considered failed.
3030 result_dir : str
@@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
6666 ... )
6767 >>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
6868 """
69+ # pylint: disable=invalid-name
70+ ALLOWED_CHARS = set (string .digits + string .ascii_letters + "_-[]()" )
71+ KEYWORD_ONLY = inspect .Parameter .KEYWORD_ONLY
6972
7073 def decorator (func ):
74+ import pytest
7175
7276 os .makedirs (result_dir , exist_ok = True )
7377 old_sig = inspect .signature (func )
7478
75- def wrapper (* args , ** kwargs ):
79+ @pytest .mark .parametrize ("ext" , extensions )
80+ def wrapper (* args , ext = "png" , request = None , ** kwargs ):
81+ if "ext" in old_sig .parameters :
82+ kwargs ["ext" ] = ext
83+ if "request" in old_sig .parameters :
84+ kwargs ["request" ] = request
85+ try :
86+ file_name = "" .join (c for c in request .node .name if c in ALLOWED_CHARS )
87+ except AttributeError : # 'NoneType' object has no attribute 'node'
88+ file_name = func .__name__
7689 try :
7790 fig_ref , fig_test = func (* args , ** kwargs )
78- ref_image_path = os .path .join (
79- result_dir , func .__name__ + "-expected.png"
80- )
81- test_image_path = os .path .join (result_dir , func .__name__ + ".png" )
91+ ref_image_path = os .path .join (result_dir , f"{ file_name } -expected.{ ext } " )
92+ test_image_path = os .path .join (result_dir , f"{ file_name } .{ ext } " )
8293 fig_ref .savefig (ref_image_path )
8394 fig_test .savefig (test_image_path )
8495
@@ -109,9 +120,18 @@ def wrapper(*args, **kwargs):
109120 for param in old_sig .parameters .values ()
110121 if param .name not in {"fig_test" , "fig_ref" }
111122 ]
123+ if "ext" not in old_sig .parameters :
124+ parameters += [inspect .Parameter ("ext" , KEYWORD_ONLY )]
125+ if "request" not in old_sig .parameters :
126+ parameters += [inspect .Parameter ("request" , KEYWORD_ONLY )]
112127 new_sig = old_sig .replace (parameters = parameters )
113128 wrapper .__signature__ = new_sig
114129
130+ # reach a bit into pytest internals to hoist the marks from
131+ # our wrapped function
132+ new_marks = getattr (func , "pytestmark" , []) + wrapper .pytestmark
133+ wrapper .pytestmark = new_marks
134+
115135 return wrapper
116136
117137 return decorator
0 commit comments