
Commit b1cc085

Add more lookahead optimizer tests (#53)
1 parent 57d63c5 commit b1cc085

5 files changed: +31 −5 lines changed


tests/test_basic.py

Lines changed: 9 additions & 0 deletions

@@ -38,6 +38,11 @@ def ids(v):
     return n


+def build_lookahead(*a, **kw):
+    base = optim.Yogi(*a, **kw)
+    return optim.Lookahead(base)
+
+
 optimizers = [
     (
         optim.NovoGrad,
@@ -51,6 +56,7 @@ def ids(v):
     (optim.AdaBound, {'lr': 1.0}, 800),
     (optim.Yogi, {'lr': 1.0}, 500),
     (optim.AccSGD, {'lr': 0.015}, 800),
+    (build_lookahead, {'lr': 1.0}, 500),
 ]


@@ -69,3 +75,6 @@ def test_benchmark_function(case, optimizer_config):
         f.backward(retain_graph=True)
         optimizer.step()
     assert torch.allclose(x, x_min, atol=0.001)
+
+    name = optimizer.__class__.__name__
+    assert name in optimizer.__repr__()
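
For context on how the new entry is consumed: each item in `optimizers` is a callable, a kwargs dict, and an iteration budget, so the `build_lookahead` factory slots in exactly like an optimizer class. A minimal, self-contained sketch of that pattern follows; the quadratic toy objective, the loose convergence check, and the hyperparameters are illustrative stand-ins, not the repository's actual benchmark fixtures.

import torch
import torch_optimizer as optim


def build_lookahead(*a, **kw):
    # Same shape as the test helper: wrap a Yogi base optimizer in Lookahead.
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)


# Minimize f(x) = (x - 3)^2 with the wrapped optimizer.
x = torch.tensor([0.0], requires_grad=True)
target = torch.tensor([3.0])
optimizer = build_lookahead([x], lr=0.5)
initial_loss = ((x - target) ** 2).sum().item()

for _ in range(500):
    optimizer.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    optimizer.step()

# The wrapped optimizer should have made progress on the toy problem ...
assert ((x - target) ** 2).sum().item() < initial_loss
# ... and, with the __repr__ added in this commit, its class name appears in repr().
assert optimizer.__class__.__name__ in optimizer.__repr__()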

tests/test_optimizer_with_nn.py

Lines changed: 6 additions & 0 deletions

@@ -45,6 +45,11 @@ def ids(v):
     return f'{v[0].__name__} {v[1:]}'


+def build_lookahead(*a, **kw):
+    base = optim.Yogi(*a, **kw)
+    return optim.Lookahead(base)
+
+
 optimizers = [
     (optim.NovoGrad, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
     (optim.Lamb, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
@@ -55,6 +60,7 @@ def ids(v):
     (optim.Yogi, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
     (optim.RAdam, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
     (optim.AccSGD, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (build_lookahead, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
 ]
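
The nn-based test exercises the same factory with `weight_decay` forwarded through to the underlying Yogi. A short illustrative sketch of that wiring in isolation; the tiny model, random data, and loop length are stand-ins rather than the test's actual network and fixtures.

import torch
import torch.nn as nn
import torch_optimizer as optim


def build_lookahead(*a, **kw):
    # Factory from the tests: Lookahead wrapping a Yogi base optimizer.
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)


# Stand-in model and data; the kwargs mirror the new parametrization entry.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
inputs = torch.randn(32, 4)
targets = torch.randn(32, 1)
loss_fn = nn.MSELoss()
optimizer = build_lookahead(model.parameters(), lr=0.1, weight_decay=1e-3)

for _ in range(200):
    optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()
    optimizer.step()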

torch_optimizer/lookahead.py

Lines changed: 12 additions & 1 deletion

@@ -24,7 +24,7 @@ class Lookahead(Optimizer):
     Example:
         >>> import torch_optimizer as optim
         >>> yogi = optim.Yogi(model.parameters(), lr=0.1)
-        >>> optimizer = optim.Lookahead(yogi, k=5)
+        >>> optimizer = optim.Lookahead(yogi, k=5, alpha=0.5)
         >>> optimizer.zero_grad()
         >>> loss_fn(model(input), target).backward()
         >>> optimizer.step()
@@ -116,3 +116,14 @@ def load_state_dict(self, state_dict: State) -> None:
     def zero_grad(self) -> None:
         r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
         self.optimizer.zero_grad()
+
+    def __repr__(self) -> str:
+        base_str = self.optimizer.__repr__()
+        format_string = self.__class__.__name__ + ' ('
+        format_string += '\n'
+        format_string += f'k: {self.k}\n'
+        format_string += f'alpha: {self.alpha}\n'
+        format_string += base_str
+        format_string += '\n'
+        format_string += ')'
+        return format_string
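
Given the `__repr__` above, printing a wrapped optimizer nests the base optimizer's own repr inside a Lookahead header, which is what the new test assertion checks. A rough illustration; the base Yogi lines come from PyTorch's default `Optimizer.__repr__` and are abbreviated here.

import torch
import torch_optimizer as optim

param = torch.nn.Parameter(torch.zeros(3))
optimizer = optim.Lookahead(optim.Yogi([param], lr=0.1), k=5, alpha=0.5)

print(optimizer)
# Expected shape of the output (base-optimizer details abbreviated):
# Lookahead (
# k: 5
# alpha: 0.5
# Yogi (
# Parameter Group 0
#     ...
# )
# )

assert 'Lookahead' in repr(optimizer)
assert 'k: 5' in repr(optimizer)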

torch_optimizer/novograd.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from torch.optim import Optimizer
+from torch.optim.optimizer import Optimizer

 from .types import Betas2, OptFloat, OptLossClosure, Params

torch_optimizer/sgdw.py

Lines changed: 3 additions & 3 deletions

@@ -35,9 +35,9 @@ def __init__(
         self,
         params: Params,
         lr: float = 1e-3,
-        momentum: float = 0,
-        dampening: float = 0,
-        weight_decay: float = 1e-2,
+        momentum: float = 0.0,
+        dampening: float = 0.0,
+        weight_decay: float = 0.0,
         nesterov: bool = False,
     ) -> None:
         if not 0.0 <= lr:
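
One behavioral note on this change: the default decoupled weight decay drops from 1e-2 to 0.0, so callers who relied on the old implicit decay now have to request it explicitly. A quick sketch, assuming the class is exported as `optim.SGDW`:

import torch
import torch_optimizer as optim

params = [torch.nn.Parameter(torch.zeros(3))]

# With the new defaults, no weight decay is applied unless requested:
plain = optim.SGDW(params, lr=1e-3)                       # weight_decay=0.0
decayed = optim.SGDW(params, lr=1e-3, weight_decay=1e-2)  # old default, now explicit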
