@@ -339,30 +339,34 @@ function model_tuning_ic!(
339339end
340340
341341"""
342- model_tuning_cv!(model, regularizers, blocks, periods; parallel=false, verbose=false, kwargs...) -> (model_opt, index_opt)
342+ model_tuning_cv!(model, regularizers, blocks, periods; metric=:mse, parallel=false, verbose=false, kwargs...) -> (model_opt, index_opt)
343343
344344Search for the optimal regularizer in `regularizers` for the dynamic factor
345345model `model` using cross-validation with out-of-sample consisting of `blocks`
346- blocks with `periods` period ahead forecasts. If `parallel` is true, the search
347- is performed in parallel. If `verbose` is true, a summary of model tuning and
348- progress of the search is printed. Additional keyword arguments `kwargs` are
349- passed to the `fit!` function.
346+ blocks with `periods` period ahead forecasts and metric `metric` . If `parallel`
347+ is true, the search is performed in parallel. If `verbose` is true, a summary of
348+ model tuning and progress of the search is printed. Additional keyword arguments
349+ `kwargs` are passed to the `fit!` function.
350350"""
351351function model_tuning_cv! (
352352 model:: DynamicFactorModel ,
353353 regularizers:: AbstractArray ,
354354 blocks:: Integer ,
355355 periods:: Integer ;
356+ metric:: Symbol = :mse ,
356357 parallel:: Bool = false ,
357358 verbose:: Bool = false ,
358359 kwargs...
359360)
361+ metric ∉ (:mse , :mae ) && error (" Accuracy matric $metric not supported." )
362+
360363 if verbose
361364 println (" Model tuning summary" )
362365 println (" ====================" )
363366 println (" Number of regularizers: $(length (regularizers)) " )
364367 println (" Number of out-of-sample blocks: $blocks " )
365368 println (" Forecast periods per block: $periods " )
369+ println (" Forecast accuracy metric: $metric " )
366370 println (" Parallel: $(parallel ? " yes" : " no" ) " )
367371 println (" ====================" )
368372 end
@@ -387,27 +391,31 @@ function model_tuning_cv!(
387391 missing
388392 end
389393 end
390- msfe = map (θ) do θi
394+ avg_loss = map (θ) do θi
391395 if all (ismissing .(θi))
392396 missing
393397 else
394- e_sq = zero (eltype (data (model)))
398+ loss = zero (eltype (data (model)))
395399 for (t, test_model) ∈ pairs (test_models)
396400 params! (test_model, θi)
397401 oos_range = (T_train + t): (T_train + t + periods - 1 )
398- e_sq += sum (abs2, view (data (model), :, oos_range) - forecast (test_model, periods))
402+ if metric == :mse
403+ loss += sum (abs2, view (data (model), :, oos_range) - forecast (test_model, periods))
404+ elseif metric == :mae
405+ loss += sum (abs, view (data (model), :, oos_range) - forecast (test_model, periods))
406+ end
399407 end
400- e_sq / (n * length (test_models) * periods)
408+ loss / (n * length (test_models) * periods)
401409 end
402410 end
403- (msfe_opt , index_opt) = findmin (x -> isnan (x) ? Inf : x, skipmissing (msfe ))
411+ (avg_loss_opt , index_opt) = findmin (x -> isnan (x) ? Inf : x, skipmissing (avg_loss ))
404412 fit! (model, regularizer= regularizers[index_opt]; kwargs... )
405413
406414 if verbose
407415 println (" ====================" )
408416 println (" Optimal regularizer index: $(index_opt) " )
409- println (" Optimal forecast accuracy: $(msfe_opt ) " )
410- println (" Failed fits: $(sum (ismissing .(msfe ))) " )
417+ println (" Optimal forecast accuracy: $(avg_loss_opt ) " )
418+ println (" Failed fits: $(sum (ismissing .(avg_loss ))) " )
411419 println (" ====================" )
412420 end
413421
0 commit comments