Skip to content

Commit 1e0cbf7

Browse files
committed
add tests for changing converters
1 parent 35d64f5 commit 1e0cbf7

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

specparam/tests/models/test_model.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def test_fit_nk():
7676

7777
# Check model results - gaussian parameters
7878
for ii, gauss in enumerate(groupby(gauss_params, 3)):
79-
assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])
79+
assert np.allclose(gauss, \
80+
tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])
8081

8182
def test_fit_nk_noise():
8283
"""Test fit on noisy data, to make sure nothing breaks."""
@@ -107,7 +108,8 @@ def test_fit_knee():
107108

108109
# Check model results - gaussian parameters
109110
for ii, gauss in enumerate(groupby(gauss_params, 3)):
110-
assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])
111+
assert np.allclose(gauss, \
112+
tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])
111113

112114
def test_fit_default_metrics():
113115
"""Test computing metrics, post model fitting."""
@@ -138,6 +140,24 @@ def test_fit_custom_metrics():
138140
assert key in metrics
139141
assert isinstance(val, float)
140142

143+
def test_fit_null_conversions(tfm):
144+
145+
null_converters = tfm.modes.get_params('dict')
146+
ntfm = SpectralModel(converters=null_converters)
147+
148+
ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear'))
149+
assert np.all(np.isnan(ntfm.results.get_params('aperiodic', version='converted')))
150+
assert np.all(np.isnan(ntfm.results.get_params('periodic', version='converted')))
151+
152+
def test_fit_custom_conversions(tfm):
153+
154+
converters = {'periodic' : {'pw' : 'lin_sub'}}
155+
ntfm = SpectralModel(converters=converters)
156+
157+
ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear'))
158+
assert not np.array_equal(
159+
tfm.results.get_params('periodic', 'pw'), ntfm.results.get_params('periodic', 'pw'))
160+
141161
def test_checks():
142162
"""Test various checks, errors and edge cases for model fitting.
143163
This tests all the input checking done in `_prepare_data`.

0 commit comments

Comments
 (0)