
Commit e8a9ee2

Add Lion optimizer (#504)
* Create lion.py: add the Lion optimizer from https://arxiv.org/pdf/2302.06675.pdf
* Update __init__.py
* Update test_basic.py
* Update test_optimizer.py
* Update test_optimizer_with_nn.py
* Update test_optimizer_with_nn.py
* Update lion.py
* Update test_optimizer_with_nn.py
* Update test_param_validation.py
1 parent efe2ffd commit e8a9ee2

6 files changed: +104 -1 lines changed

tests/test_basic.py

Lines changed: 1 addition & 1 deletion
@@ -70,9 +70,9 @@ def build_lookahead(*a, **kw):
     (optim.Adahessian, {'lr': 0.15, 'hessian_power': 0.6, 'seed': 0}, 900),
     (optim.MADGRAD, {'lr': 0.02}, 500),
     (optim.LARS, {'lr': 0.002, 'momentum': 0.91}, 900),
+    (optim.Lion, {'lr': 0.025}, 3600),
 ]
 
-
 @pytest.mark.parametrize('case', cases, ids=ids)
 @pytest.mark.parametrize('optimizer_config', optimizers, ids=ids)
 def test_benchmark_function(case, optimizer_config):

tests/test_optimizer.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def build_lookahead(*a, **kw):
     optim.SWATS,
     optim.Shampoo,
     optim.Yogi,
+    optim.Lion,
 ]
 
 
tests/test_optimizer_with_nn.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def build_lookahead(*a, **kw):
     ),
     (optim.Yogi, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
     (optim.Adahessian, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
+    (optim.Lion, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
 ]
 
 
tests/test_param_validation.py

Lines changed: 3 additions & 0 deletions
@@ -55,6 +55,7 @@ def test_sparse_not_supported(optimizer_class):
     optim.SWATS,
     optim.Shampoo,
     optim.Yogi,
+    optim.Lion,
 ]
 
 
@@ -118,6 +119,7 @@ def test_eps_validation(optimizer_class):
     optim.SWATS,
     optim.Shampoo,
     optim.Yogi,
+    optim.Lion,
 ]
 
 
@@ -141,6 +143,7 @@ def test_weight_decay_validation(optimizer_class):
     optim.QHAdam,
     optim.RAdam,
     optim.Yogi,
+    optim.Lion,
 ]
 
 

torch_optimizer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .shampoo import Shampoo
 from .swats import SWATS
 from .yogi import Yogi
+from .lion import Lion
 
 __all__ = (
     'A2GradExp',
@@ -76,6 +77,7 @@
     'SWATS',
     'Shampoo',
     'Yogi',
+    'Lion',
     # utils
     'get',
 )
@@ -107,6 +109,7 @@
     SWATS,
     Shampoo,
     Yogi,
+    Lion,
 ] # type: List[Type[Optimizer]]
 
 
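Because Lion is added both to __all__ and to the registry list above, it should also become reachable by name through the package's existing get helper (already exported in __all__). A minimal sketch, assuming get performs the case-insensitive name lookup over this registry that the library README describes:

import torch_optimizer as optim
from torch_optimizer import Lion

# Assumption: optim.get() resolves optimizer classes by lower-case name
# from the registry list extended in this commit.
lion_cls = optim.get('lion')
assert lion_cls is Lion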

torch_optimizer/lion.py

Lines changed: 95 additions & 0 deletions
import torch
from torch.optim.optimizer import Optimizer

from .types import OptFloat, OptLossClosure, Params, Betas2

__all__ = ("Lion",)


class Lion(Optimizer):
    r"""Implements Lion algorithm.

    Adapted from https://github.com/google/automl/tree/master/lion

    The Lion - EvoLved SIgn MOmeNtum - algorithm was proposed in
    https://arxiv.org/pdf/2302.06675.pdf.
    Lion aims to be more memory efficient than Adam by tracking only
    the momentum.

    Caveats: as detailed in the paper, Lion requires a smaller learning
    rate and a larger decoupled weight decay to maintain an effective
    weight-decay strength. The gain of Lion also increases with the
    batch size, and Lion was not found to outperform AdamW on some
    large language and text/image datasets.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-4)
        betas: coefficients used for interpolating the update and for
            the momentum running average (default: (0.9, 0.99))
        weight_decay: decoupled weight decay (default: 0)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.Lion(model.parameters(), lr=0.001)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-4,
        betas: Betas2 = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        if lr <= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns
                the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform decoupled (stepweight) weight decay
                p.data.mul_(1 - group["lr"] * group["weight_decay"])

                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]
                beta1, beta2 = group["betas"]

                # Weight update: sign of the interpolation between the
                # momentum and the current gradient
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-group["lr"])
                # Decay the momentum running average coefficient
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
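For context, a minimal usage sketch of the new optimizer through the torch_optimizer package. The linear model, random data, and hyperparameter values are illustrative only; the smaller learning rate and larger decoupled weight decay follow the caveat in the docstring above, not tuned settings.

import torch
import torch.nn.functional as F

import torch_optimizer as optim

# Toy model and data, for illustration only.
model = torch.nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

# Per the docstring caveats, Lion typically wants a smaller lr and a larger
# decoupled weight decay than AdamW; these values are not tuned.
optimizer = optim.Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()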
