Skip to content

Commit 6163643

Browse files
committed
Update model tuning based on cross-validation, incorporate TreeParzen.jl
1 parent ece1777 commit 6163643

File tree

1 file changed

+61
-53
lines changed

1 file changed

+61
-53
lines changed

src/interface.jl

Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,19 @@ end
265265

266266
"""
267267
model_tuning_ic!(model, space, regularizer; trials = 100, ic = :bic, verbose = false,
268-
kwargs...) -> (model_opt, index_opt)
268+
kwargs...) -> (model, best)
269269
270270
Search for the optimal regularizer in search space `space` for the dynamic factor model
271271
`model` using information criterion `ic` and a Tree Parzen estimator performing number of
272272
trials given by `trials`, where `regularizer` is a function that creates the regularizer
273273
from a dictionary of hyperparameters. If `verbose` is true, a summary of model tuning and
274274
progress of the search is printed. Additional keyword arguments `kwargs` are passed to the
275275
`fit!` function.
276+
277+
The search space `space` should be provided as a dictionary where the values are HP
278+
functions to determine the sampling scheme for the Tree Parzen estimator. The HP module is
279+
re-exported from TreeParzen.jl for convenience; for more details on the implementation and
280+
options see https://github.com/IQVIA-ML/TreeParzen.jl.
276281
"""
277282
function model_tuning_ic!(model::DynamicFactorModel, space::Dict, regularizer::Function;
278283
trials::Integer = 100, ic::Symbol = :bic, verbose::Bool = false,
@@ -317,29 +322,40 @@ function model_tuning_ic!(model::DynamicFactorModel, space::Dict, regularizer::F
317322
end
318323

319324
"""
320-
model_tuning_cv!(model, regularizers, blocks, periods; metric = :mse, parallel = false,
321-
verbose = false, kwargs...) -> (model_opt, index_opt)
322-
323-
Search for the optimal regularizer in `regularizers` for the dynamic factor
324-
model `model` using cross-validation with out-of-sample consisting of `blocks`
325-
blocks with `periods` period ahead forecasts and metric `metric`. If `parallel`
326-
is true, the search is performed in parallel. If `verbose` is true, a summary of
327-
model tuning and progress of the search is printed. Additional keyword arguments
328-
`kwargs` are passed to the `fit!` function.
325+
model_tuning_cv!(model, space, regularizer, blocks, periods; trials = 100, metric = :mse,
326+
verbose = false, kwargs...) -> (model, best)
327+
328+
Search for the optimal regularizer in search space `space` for the dynamic factor model
329+
`model` using cross-validation with the out-of-sample consisting of `blocks` blocks with
330+
`periods` period ahead forecasts, metric `metric` and a Tree Parzen estimator performing
331+
number of trials given by `trials`, where `regularizer` is a function that creates the
332+
regularizer from a dictionary of hyperparameters. If `verbose` is true, a summary of model
333+
tuning and progress of the search is printed. Additional keyword arguments `kwargs` are
334+
passed to the `fit!` function.
335+
336+
The search space `space` should be provided as a dictionary where the values are HP
337+
functions to determine the sampling scheme for the Tree Parzen estimator. The HP module is
338+
re-exported from TreeParzen.jl for convenience; for more details on the implementation and
339+
options see https://github.com/IQVIA-ML/TreeParzen.jl.
329340
"""
330-
function model_tuning_cv!(model::DynamicFactorModel, regularizers::AbstractArray,
331-
blocks::Integer, periods::Integer; metric::Symbol = :mse,
332-
parallel::Bool = false, verbose::Bool = false, kwargs...)
333-
metric (:mse, :mae) && error("Accuracy matric $metric not supported.")
341+
function model_tuning_cv!(model::DynamicFactorModel, space::Dict, regularizer::Function,
342+
blocks::Integer, periods::Integer; trials::Integer = 100,
343+
metric::Symbol = :mse, verbose::Bool = false, kwargs...)
344+
if metric == :mse
345+
loss_function = abs2
346+
elseif metric == :mae
347+
loss_function = abs
348+
else
349+
error("Accuracy matric $metric not supported.")
350+
end
334351

335352
if verbose
336353
println("Model tuning summary")
337354
println("====================")
338-
println("Number of regularizers: $(length(regularizers))")
355+
println("Number of trials: $trials")
339356
println("Number of out-of-sample blocks: $blocks")
340357
println("Forecast periods per block: $periods")
341358
println("Forecast accuracy metric: $metric")
342-
println("Parallel: $(parallel ? "yes" : "no")")
343359
println("====================")
344360
end
345361

@@ -349,53 +365,45 @@ function model_tuning_cv!(model::DynamicFactorModel, regularizers::AbstractArray
349365
train_model = select_sample(model, 1:Ttrain)
350366
test_models = [select_sample(model, 1:t) for t in Ttrain:(T - periods)]
351367

352-
# model tuning
353-
map_func = parallel ? verbose ? progress_pmap : pmap : verbose ? progress_map : map
354-
θ0 = params(model)
355-
f0 = factors(model)[:, 1:Ttrain]
356-
θ = map_func(regularizers) do regularizer
357-
try
358-
params!(train_model, θ0)
359-
factors(train_model) .= f0
360-
fit!(train_model, regularizer = regularizer; kwargs...)
361-
params(train_model)
362-
catch
363-
missing
368+
# objective function
369+
function objective(params)
370+
371+
# fit on train sample
372+
fit!(train_model, regularizer = regularizer(params))
373+
θ = params(train_model)
374+
375+
# evaluate on test samples
376+
loss = zero(eltype(data(model)))
377+
for (t, test_model) in pairs(test_models)
378+
# out-of-sample data
379+
y = view(data(model), :, (Ttrain + t):(Ttrain + t + periods - 1))
380+
# loss
381+
params!(test_model, θ)
382+
loss += sum(loss_function, y - forecast(test_model, periods))
364383
end
384+
385+
return loss / (n * length(test_models) * periods)
365386
end
366-
avg_loss = map(θ) do θi
367-
if all(ismissing.(θi))
368-
missing
369-
else
370-
loss = zero(eltype(data(model)))
371-
for (t, test_model) in pairs(test_models)
372-
params!(test_model, θi)
373-
oos_range = (Ttrain + t):(Ttrain + t + periods - 1)
374-
if metric == :mse
375-
loss += sum(abs2,
376-
view(data(model), :, oos_range) -
377-
forecast(test_model, periods))
378-
elseif metric == :mae
379-
loss += sum(abs,
380-
view(data(model), :, oos_range) -
381-
forecast(test_model, periods))
382-
end
383-
end
384-
loss / (n * length(test_models) * periods)
387+
388+
# model tuning
389+
if verbose
390+
best = fmin(objective, space, trials)
391+
else
392+
with_logger(NullLogger()) do
393+
best = fmin(objective, space, trials)
385394
end
386395
end
387-
(avg_loss_opt, index_opt) = findmin(x -> isnan(x) ? Inf : x, skipmissing(avg_loss))
388-
fit!(model, regularizer = regularizers[index_opt]; kwargs...)
396+
397+
# refit
398+
fit!(model, regularizer = regularizer(best))
389399

390400
if verbose
391401
println("====================")
392-
println("Optimal regularizer index: $(index_opt)")
393-
println("Optimal forecast accuracy: $(avg_loss_opt)")
394-
println("Failed fits: $(sum(ismissing.(avg_loss)))")
402+
println("Optimal regularizer: $best")
395403
println("====================")
396404
end
397405

398-
return (model, index_opt)
406+
return (model, best)
399407
end
400408

401409
"""

0 commit comments

Comments
 (0)