 #= ===================
    LOGISTIC CLASSIFIER
-   =================== =#
+   =================== =#
 
 """
-Logistic Classifier (typically called "Logistic Regression"). This model is
-a standard classifier for both binary and multiclass classification.
-In the binary case it corresponds to the LogisticLoss, in the multiclass to the
-Multinomial (softmax) loss. An elastic net penalty can be applied with
-overall objective function
+$(doc_header(LogisticClassifier))
 
-``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``
+This model is more commonly known as "logistic regression". It is a standard classifier
+for both binary and multiclass classification. The objective function applies either a
+logistic loss (binary target) or a multinomial (softmax) loss, and has a mixed L1/L2
+penalty:
 
-where ``L`` is either the logistic or multinomial loss and ``λ`` and ``γ`` indicate
-the strength of the L2 (resp. L1) regularisation components and
-``n`` is the number of samples `size(X, 1)`.
-With `scale_penalty_with_samples = false` the objective function is
-``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``
+``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``.
 
-## Parameters
+Here ``L`` is either `MLJLinearModels.LogisticLoss` or `MLJLinearModels.MultiClassLoss`,
+``λ`` and ``γ`` indicate
+the strength of the L2 (resp. L1) regularization components, and
+``n`` is the number of training observations.
+
+With `scale_penalty_with_samples = false` the objective function is instead
+
+``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+    mach = machine(model, X, y)
+
+where:
+
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  have `Continuous` scitype; check column scitypes with `schema(X)`
+
+- `y` is the target, which can be any `AbstractVector` whose element
+  scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+  with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+
+# Hyperparameters
 
 $TYPEDFIELDS
 
 $(example_docstring("LogisticClassifier", nclasses = 2))
+
+See also [`MultinomialClassifier`](@ref).
+
 """
 @with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
-    "strength of the regulariser if `penalty` is `:l2` or `:l1` and strength of the L2
-    regulariser if `penalty` is `:en`."
+    "strength of the regularizer if `penalty` is `:l2` or `:l1` and strength of the L2
+    regularizer if `penalty` is `:en`."
     lambda::Real = eps()
-    "strength of the L1 regulariser if `penalty` is `:en`."
+    "strength of the L1 regularizer if `penalty` is `:en`."
     gamma::Real = 0.0
     "the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
     penalty::SymStr = :l2
@@ -37,7 +62,18 @@ $(example_docstring("LogisticClassifier", nclasses = 2))
     penalize_intercept::Bool = false
     "whether to scale the penalty with the number of samples."
     scale_penalty_with_samples::Bool = true
-    "type of solver to use, default if `nothing`."
+    """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`, `Newton`,
+    `NewtonCG`, `ProxGrad`; but subject to the following restrictions:
+
+    - If `gamma > 0` (L1 norm penalized) then only `ProxGrad` is allowed.
+
+    - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
+
+    If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
+    unless `gamma = 0`, in which case `LBFGS()` is used.
+
+    Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
+    `ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
     solver::Option{Solver} = nothing
 end
 
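An editorial aside, not part of the diff: the elastic-net objective described in the docstring above is easy to sanity-check numerically. A minimal sketch of the penalty term with and without `scale_penalty_with_samples`, using made-up values for `θ`, `λ`, `γ` and `n`:

```julia
# Penalty term of the docstring's objective: n⋅λ|θ|₂²/2 + n⋅γ|θ|₁
# (drop the n factor when scale_penalty_with_samples = false).
θ = [0.5, -1.0, 2.0]        # coefficient vector (made up)
λ, γ, n = 0.1, 0.05, 100    # L2 strength, L1 strength, number of observations

penalty_scaled   = n * (λ * sum(abs2, θ) / 2 + γ * sum(abs, θ))
penalty_unscaled = λ * sum(abs2, θ) / 2 + γ * sum(abs, θ)

# sum(abs2, θ) = 5.25 and sum(abs, θ) = 3.5, so:
# penalty_unscaled = 0.1*5.25/2 + 0.05*3.5 = 0.4375
# penalty_scaled   = 100 * 0.4375 = 43.75
```

This makes the role of the `n` factor concrete: with scaling on (the default), the penalty grows with the number of observations, keeping its influence comparable to that of the data-dependent loss term.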
@@ -50,27 +86,49 @@ glr(m::LogisticClassifier, nclasses::Integer) =
         scale_penalty_with_samples=m.scale_penalty_with_samples,
         nclasses=nclasses)
 
-descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the logistic loss."
-
 #= ======================
    MULTINOMIAL CLASSIFIER
    ====================== =#
 
 """
-See `LogisticClassifier`, it's the same except that multiple classes are assumed
-by default. The other parameters are the same.
+$(doc_header(MultinomialClassifier))
+
+This model coincides with [`LogisticClassifier`](@ref), except certain optimizations
+possible in the special binary case will not be applied. Its hyperparameters are
+identical.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+    mach = machine(model, X, y)
+
+where:
+
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  have `Continuous` scitype; check column scitypes with `schema(X)`
+
+- `y` is the target, which can be any `AbstractVector` whose element
+  scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+  with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
 
-## Parameters
+
+# Hyperparameters
 
 $TYPEDFIELDS
 
-$(example_docstring("LogisticClassifier", nclasses = 3))
+$(example_docstring("MultinomialClassifier", nclasses = 3))
+
+See also [`LogisticClassifier`](@ref).
+
 """
 @with_kw_noshow mutable struct MultinomialClassifier <: MMI.Probabilistic
-    "strength of the regulariser if `penalty` is `:l2` or `:l1`.
-    Strength of the L2 regulariser if `penalty` is `:en`."
+    "strength of the regularizer if `penalty` is `:l2` or `:l1`.
+    Strength of the L2 regularizer if `penalty` is `:en`."
     lambda::Real = eps()
-    "strength of the L1 regulariser if `penalty` is `:en`."
+    "strength of the L1 regularizer if `penalty` is `:en`."
     gamma::Real = 0.0
     "the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
     penalty::SymStr = :l2
@@ -80,7 +138,18 @@ $(example_docstring("LogisticClassifier", nclasses = 3))
     penalize_intercept::Bool = false
     "whether to scale the penalty with the number of samples."
     scale_penalty_with_samples::Bool = true
-    "type of solver to use, default if `nothing`."
+    """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`,
+    `NewtonCG`, `ProxGrad`; but subject to the following restrictions:
+
+    - If `gamma > 0` (L1 norm penalized) then only `ProxGrad` is allowed.
+
+    - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
+
+    If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
+    unless `gamma = 0`, in which case `LBFGS()` is used.
+
+    Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
+    `ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
     solver::Option{Solver} = nothing
 end
 
@@ -91,7 +160,3 @@ glr(m::MultinomialClassifier, nclasses::Integer) =
         penalize_intercept=m.penalize_intercept,
         scale_penalty_with_samples=m.scale_penalty_with_samples,
         nclasses=nclasses)
-
-descr(::Type{MultinomialClassifier}) =
-    "Classifier corresponding to the loss function " *
-    "``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the multinomial loss."
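An editorial aside, not part of the diff: the "Training data" instructions in the new docstrings can be exercised end to end. A hedged usage sketch, assuming MLJ.jl and MLJLinearModels.jl are installed and using MLJ's `make_blobs` to fabricate a toy table and categorical target:

```julia
using MLJ  # assumes MLJ.jl and MLJLinearModels.jl are in the environment

# @load returns the model type exported by MLJLinearModels
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels

X, y = make_blobs(200, 3; centers=2)  # toy Continuous features and Multiclass target
model = LogisticClassifier(lambda=0.5, gamma=0.1, penalty=:en)

mach = machine(model, X, y)   # bind model to data, as in the docstring
fit!(mach)

probs  = predict(mach, X)       # probabilistic predictions (model is Probabilistic)
labels = predict_mode(mach, X)  # point predictions
```

Since `gamma > 0` here, the default solver resolves to `ProxGrad(accel=true)` (FISTA), per the solver restrictions documented above.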