Skip to content

Commit 456f724

Browse files
committed
Add option for different loss metrics in cross-validation
1 parent 38872c1 commit 456f724

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

src/interface.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -339,30 +339,34 @@ function model_tuning_ic!(
339339
end
340340

341341
"""
342-
model_tuning_cv!(model, regularizers, blocks, periods; parallel=false, verbose=false, kwargs...) -> (model_opt, index_opt)
342+
model_tuning_cv!(model, regularizers, blocks, periods; metric=:mse, parallel=false, verbose=false, kwargs...) -> (model_opt, index_opt)
343343
344344
Search for the optimal regularizer in `regularizers` for the dynamic factor
345345
model `model` using cross-validation with out-of-sample consisting of `blocks`
346-
blocks with `periods` period ahead forecasts. If `parallel` is true, the search
347-
is performed in parallel. If `verbose` is true, a summary of model tuning and
348-
progress of the search is printed. Additional keyword arguments `kwargs` are
349-
passed to the `fit!` function.
346+
blocks with `periods` period ahead forecasts and metric `metric`. If `parallel`
347+
is true, the search is performed in parallel. If `verbose` is true, a summary of
348+
model tuning and progress of the search is printed. Additional keyword arguments
349+
`kwargs` are passed to the `fit!` function.
350350
"""
351351
function model_tuning_cv!(
352352
model::DynamicFactorModel,
353353
regularizers::AbstractArray,
354354
blocks::Integer,
355355
periods::Integer;
356+
metric::Symbol=:mse,
356357
parallel::Bool=false,
357358
verbose::Bool=false,
358359
kwargs...
359360
)
361+
metric (:mse, :mae) && error("Accuracy matric $metric not supported.")
362+
360363
if verbose
361364
println("Model tuning summary")
362365
println("====================")
363366
println("Number of regularizers: $(length(regularizers))")
364367
println("Number of out-of-sample blocks: $blocks")
365368
println("Forecast periods per block: $periods")
369+
println("Forecast accuracy metric: $metric")
366370
println("Parallel: $(parallel ? "yes" : "no")")
367371
println("====================")
368372
end
@@ -387,27 +391,31 @@ function model_tuning_cv!(
387391
missing
388392
end
389393
end
390-
msfe = map(θ) do θi
394+
avg_loss = map(θ) do θi
391395
if all(ismissing.(θi))
392396
missing
393397
else
394-
e_sq = zero(eltype(data(model)))
398+
loss = zero(eltype(data(model)))
395399
for (t, test_model) pairs(test_models)
396400
params!(test_model, θi)
397401
oos_range = (T_train + t):(T_train + t + periods - 1)
398-
e_sq += sum(abs2, view(data(model), :, oos_range) - forecast(test_model, periods))
402+
if metric == :mse
403+
loss += sum(abs2, view(data(model), :, oos_range) - forecast(test_model, periods))
404+
elseif metric == :mae
405+
loss += sum(abs, view(data(model), :, oos_range) - forecast(test_model, periods))
406+
end
399407
end
400-
e_sq / (n * length(test_models) * periods)
408+
loss / (n * length(test_models) * periods)
401409
end
402410
end
403-
(msfe_opt, index_opt) = findmin(x -> isnan(x) ? Inf : x, skipmissing(msfe))
411+
(avg_loss_opt, index_opt) = findmin(x -> isnan(x) ? Inf : x, skipmissing(avg_loss))
404412
fit!(model, regularizer=regularizers[index_opt]; kwargs...)
405413

406414
if verbose
407415
println("====================")
408416
println("Optimal regularizer index: $(index_opt)")
409-
println("Optimal forecast accuracy: $(msfe_opt)")
410-
println("Failed fits: $(sum(ismissing.(msfe)))")
417+
println("Optimal forecast accuracy: $(avg_loss_opt)")
418+
println("Failed fits: $(sum(ismissing.(avg_loss)))")
411419
println("====================")
412420
end
413421

0 commit comments

Comments
 (0)