Skip to content

Commit 20b6dc2

Browse files
committed
add load funcs & io updates
1 parent 6951711 commit 20b6dc2

File tree

4 files changed

+153
-14
lines changed

4 files changed

+153
-14
lines changed

specparam/core/io.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,125 @@ def save_event(event, file_name, file_path=None, append=False,
236236
save_settings=save_settings, save_data=save_data)
237237

238238

239+
def load_model(file_name, file_path=None, regenerate=True, model=None):
240+
"""Load a SpectralModel object.
241+
242+
Parameters
243+
----------
244+
Parameters
245+
----------
246+
file_name : str
247+
File(s) to load data from.
248+
file_path : str, optional
249+
Path to directory to load from. If None, loads from current directory.
250+
regenerate : bool, optional, default: True
251+
Whether to regenerate the model fit from the loaded data, if data is available.
252+
model : SpectralModel
253+
xx
254+
255+
Returns
256+
-------
257+
model : SpectralModel
258+
Loaded model object with data from file.
259+
"""
260+
261+
# Check for model object, import (avoid circular) and initialize if not
262+
if not model:
263+
from specparam.objs import SpectralModel
264+
model = SpectralModel()
265+
266+
model.load(file_name, file_path, regenerate)
267+
268+
return model
269+
270+
271+
def load_group(file_name, file_path=None, group=None):
272+
"""Load a SpectralGroupModel object.
273+
274+
Parameters
275+
----------
276+
file_name : str
277+
File(s) to load data from.
278+
file_path : str, optional
279+
Path to directory to load from. If None, loads from current directory.
280+
group : SpectralGroupModel
281+
xx
282+
283+
Returns
284+
-------
285+
group : SpectralGroupModel
286+
Loaded model object with data from file.
287+
"""
288+
289+
# Check for model object, import (avoid circular) and initialize if not
290+
if not group:
291+
from specparam.objs import SpectralGroupModel
292+
group = SpectralGroupModel()
293+
294+
group.load(file_name, file_path)
295+
296+
return group
297+
298+
299+
def load_time(file_name, file_path=None, peak_org=None, time=None):
300+
"""Load a SpectralTimeModel object.
301+
302+
Parameters
303+
----------
304+
file_name : str
305+
File(s) to load data from.
306+
file_path : str, optional
307+
Path to directory to load from. If None, loads from current directory.
308+
peak_org : int or Bands, optional
309+
How to organize peaks.
310+
If int, extracts the first n peaks.
311+
If Bands, extracts peaks based on band definitions.
312+
313+
Returns
314+
-------
315+
time : SpectralTimeModel
316+
Loaded model object with data from file.
317+
"""
318+
319+
# Check for model object, import (avoid circular) and initialize if not
320+
if not time:
321+
from specparam.objs import SpectralTimeModel
322+
time = SpectralTimeModel()
323+
324+
time.load(file_name, file_path, peak_org)
325+
326+
return time
327+
328+
def load_event(file_name, file_path=None, peak_org=None, event=None):
329+
"""Load a SpectralTimeEventModel object.
330+
331+
Parameters
332+
----------
333+
file_name : str
334+
File(s) to load data from.
335+
file_path : str, optional
336+
Path to directory to load from. If None, loads from current directory.
337+
peak_org : int or Bands, optional
338+
How to organize peaks.
339+
If int, extracts the first n peaks.
340+
If Bands, extracts peaks based on band definitions.
341+
342+
Returns
343+
-------
344+
event : SpectralTimeEventModel
345+
Loaded model object with data from file.
346+
"""
347+
348+
# Check for model object, import (avoid circular) and initialize if not
349+
if not event:
350+
from specparam.objs import SpectralTimeEventModel
351+
event = SpectralTimeEventModel()
352+
353+
event.load(file_name, file_path, peak_org)
354+
355+
return event
356+
357+
239358
def load_json(file_name, file_path):
240359
"""Load json file.
241360

specparam/tests/core/test_io.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def test_save_model_str(tfm):
4141
"""Check saving model object data, with file specifiers as strings."""
4242

4343
# Test saving out each set of save elements
44-
file_name_res = 'test_res'
45-
file_name_set = 'test_set'
46-
file_name_dat = 'test_dat'
44+
file_name_res = 'test_model_res'
45+
file_name_set = 'test_model_set'
46+
file_name_dat = 'test_model_dat'
4747

4848
save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False)
4949
save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False)
@@ -54,14 +54,14 @@ def test_save_model_str(tfm):
5454
assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json'))
5555

5656
# Test saving out all save elements
57-
file_name_all = 'test_all'
57+
file_name_all = 'test_model_all'
5858
save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True)
5959
assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json'))
6060

6161
def test_save_model_append(tfm):
6262
"""Check saving fm data, appending to a file."""
6363

64-
file_name = 'test_append'
64+
file_name = 'test_model_append'
6565

6666
save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True)
6767
save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True)
@@ -71,7 +71,7 @@ def test_save_model_append(tfm):
7171
def test_save_model_fobj(tfm):
7272
"""Check saving fm data, with file object file specifier."""
7373

74-
file_name = 'test_fileobj'
74+
file_name = 'test_model_fileobj'
7575

7676
# Save, using file-object: three successive lines with three possible save settings
7777
with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj:
@@ -163,12 +163,32 @@ def test_save_event(tfe):
163163
for ind in range(len(tfe)):
164164
assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json'))
165165

166+
def test_load_model():
167+
168+
tmodel = load_model('test_model_all', TEST_DATA_PATH)
169+
assert tmodel
170+
171+
def test_load_group():
172+
173+
tgroup = load_group('test_group_all', TEST_DATA_PATH)
174+
assert tgroup
175+
176+
def test_load_time():
177+
178+
ttime = load_time('test_time_all', TEST_DATA_PATH)
179+
assert ttime
180+
181+
def test_load_event():
182+
183+
tevent = load_event('test_event_all', TEST_DATA_PATH)
184+
assert tevent
185+
166186
def test_load_json_str():
167187
"""Test loading JSON file, with str file specifier.
168188
Loads files from test_save_model_str.
169189
"""
170190

171-
file_name = 'test_all'
191+
file_name = 'test_model_all'
172192

173193
data = load_json(file_name, TEST_DATA_PATH)
174194

@@ -179,7 +199,7 @@ def test_load_json_fobj():
179199
Loads files from test_save_model_str.
180200
"""
181201

182-
file_name = 'test_all'
202+
file_name = 'test_model_all'
183203

184204
with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj:
185205
data = load_json(f_obj, '')
@@ -201,7 +221,7 @@ def test_load_file_contents():
201221
Note that is this test fails, it likely stems from an issue from saving.
202222
"""
203223

204-
file_name = 'test_all'
224+
file_name = 'test_model_all'
205225
loaded_data = load_json(file_name, TEST_DATA_PATH)
206226

207227
# Check settings

specparam/tests/objs/test_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_load():
182182

183183
# Test loading just results
184184
tfm = SpectralModel(verbose=False)
185-
file_name_res = 'test_res'
185+
file_name_res = 'test_model_res'
186186
tfm.load(file_name_res, TEST_DATA_PATH)
187187
# Check that result attributes get filled
188188
for result in OBJ_DESC['results']:
@@ -196,7 +196,7 @@ def test_load():
196196

197197
# Test loading just settings
198198
tfm = SpectralModel(verbose=False)
199-
file_name_set = 'test_set'
199+
file_name_set = 'test_model_set'
200200
tfm.load(file_name_set, TEST_DATA_PATH)
201201
for setting in OBJ_DESC['settings']:
202202
assert getattr(tfm, setting) is not None
@@ -207,7 +207,7 @@ def test_load():
207207

208208
# Test loading just data
209209
tfm = SpectralModel(verbose=False)
210-
file_name_dat = 'test_dat'
210+
file_name_dat = 'test_model_dat'
211211
tfm.load(file_name_dat, TEST_DATA_PATH)
212212
assert tfm.power_spectrum is not None
213213
# Test that settings and results are None
@@ -218,7 +218,7 @@ def test_load():
218218

219219
# Test loading all elements
220220
tfm = SpectralModel(verbose=False)
221-
file_name_all = 'test_all'
221+
file_name_all = 'test_model_all'
222222
tfm.load(file_name_all, TEST_DATA_PATH)
223223
for result in OBJ_DESC['results']:
224224
assert not np.all(np.isnan(getattr(tfm, result)))

specparam/tests/utils/test_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
def test_load_model():
1717

18-
file_name = 'test_all'
18+
file_name = 'test_model_all'
1919

2020
tfm = load_model(file_name, TEST_DATA_PATH)
2121

0 commit comments

Comments
 (0)