
Commit f53d18b

Add vizualization code to the repository (#56)
1 parent c878bef commit f53d18b

File tree: 3 files changed, +228 −0 lines changed

README.rst

Lines changed: 36 additions & 0 deletions
@@ -13,6 +13,7 @@ torch-optimizer
**torch-optimizer** -- collection of optimizers for PyTorch_.



Simple example
--------------

@@ -65,6 +66,38 @@ Supported Optimizers
+-------------+-------------------------------------------------------------------------------+


Visualisations
--------------
Visualisations help us to see how different algorithms deal with simple
situations like saddle points, local minima, and valleys, and may provide
interesting insights into the inner workings of an algorithm. The Rosenbrock_
and Rastrigin_ benchmark_ functions were selected (their definitions are given
below) because:

* Rosenbrock_ (also known as the banana function) is a non-convex function
  with one global minimum at `(1.0, 1.0)`. The global minimum lies inside a
  long, narrow, parabolic, flat valley. Finding the valley is trivial;
  converging to the global minimum, however, is difficult. Optimization
  algorithms may pay a lot of attention to one coordinate and struggle to
  follow the valley, which is relatively flat.

.. image:: https://upload.wikimedia.org/wikipedia/commons/3/32/Rosenbrock_function.svg

* The Rastrigin_ function is non-convex and has one global minimum at
  `(0.0, 0.0)`. Finding this minimum is a fairly difficult problem due to the
  function's large search space and its large number of local minima.

.. image:: https://upload.wikimedia.org/wikipedia/commons/8/8b/Rastrigin_function.png

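For reference, the definitions implemented in examples/viz_optimizers.py below
are (Rastrigin with A = 10):

.. math::

    f_{\text{Rosenbrock}}(x, y) = (1 - x)^2 + 100\,(y - x^2)^2

    f_{\text{Rastrigin}}(x, y) = 2A + \bigl(x^2 - A\cos(2\pi x)\bigr) + \bigl(y^2 - A\cos(2\pi y)\bigr)
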
Each optimizer performs `501` optimization steps. The learning rate is the best
one found by a hyperparameter search algorithm; the rest of the tuning
parameters are left at their defaults. It is easy to extend the script and tune
other optimizer parameters (see the sketch below).


.. code::

    python examples/viz_optimizers.py

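As one such extension (a hypothetical sketch, not part of the shipped script),
tuning weight_decay alongside the learning rate needs only an extra entry in
the hyperopt search space, plus forwarding it into optimizer_config inside the
objective function:

.. code:: python

    from hyperopt import hp

    import torch_optimizer as optim

    space = {
        'optimizer_class': hp.choice('optimizer_class', [optim.RAdam]),
        'lr': hp.loguniform('lr', -8, 0.7),
        # hypothetical extra dimension; the objective would then build
        # optimizer_config = dict(lr=params['lr'],
        #                         weight_decay=params['weight_decay'])
        'weight_decay': hp.loguniform('weight_decay', -10, -2),
    }
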
AccSGD
------

@@ -322,3 +355,6 @@ learning rate control, and has similar theoretical guarantees on convergence as

.. _Python: https://www.python.org
.. _PyTorch: https://github.com/pytorch/pytorch
.. _Rastrigin: https://en.wikipedia.org/wiki/Rastrigin_function
.. _Rosenbrock: https://en.wikipedia.org/wiki/Rosenbrock_function
.. _benchmark: https://en.wikipedia.org/wiki/Test_functions_for_optimization

examples/requirements-examples.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
torch==1.4.0
hyperopt==0.2.3
torchvision==0.5.0
matplotlib==3.1.3

examples/viz_optimizers.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
import math
import numpy as np
import torch_optimizer as optim
import torch
from hyperopt import fmin, tpe, hp
import matplotlib.pyplot as plt


plt.style.use('seaborn-white')


def rosenbrock(tensor):
    # https://en.wikipedia.org/wiki/Test_functions_for_optimization
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def rastrigin(tensor, lib=torch):
    # https://en.wikipedia.org/wiki/Test_functions_for_optimization
    x, y = tensor
    A = 10
    f = (
        A * 2
        + (x ** 2 - A * lib.cos(x * math.pi * 2))
        + (y ** 2 - A * lib.cos(y * math.pi * 2))
    )
    return f


def execute_steps(
    func, initial_state, optimizer_class, optimizer_config, num_iter=500
):
    x = torch.Tensor(initial_state).requires_grad_(True)
    optimizer = optimizer_class([x], **optimizer_config)
    # record the full trajectory, including the starting point
    steps = np.zeros((2, num_iter + 1))
    steps[:, 0] = np.array(initial_state)
    for i in range(1, num_iter + 1):
        optimizer.zero_grad()
        f = func(x)
        f.backward(retain_graph=True)
        optimizer.step()
        steps[:, i] = x.detach().numpy()
    return steps


def objective_rastrigin(params):
    lr = params['lr']
    optimizer_class = params['optimizer_class']
    initial_state = (-2.0, 3.5)
    minimum = (0, 0)
    optimizer_config = dict(lr=lr)
    num_iter = 100
    steps = execute_steps(
        rastrigin, initial_state, optimizer_class, optimizer_config, num_iter
    )
    # squared distance of the final iterate from the known global minimum
    return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2


def objective_rosenbrok(params):
    lr = params['lr']
    optimizer_class = params['optimizer_class']
    minimum = (1.0, 1.0)
    initial_state = (-2.0, 2.0)
    optimizer_config = dict(lr=lr)
    num_iter = 100
    steps = execute_steps(
        rosenbrock, initial_state, optimizer_class, optimizer_config, num_iter
    )
    return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2


def plot_rastrigin(grad_iter, optimizer_name, lr):
    x = np.linspace(-4.5, 4.5, 250)
    y = np.linspace(-4.5, 4.5, 250)
    minimum = (0, 0)

    X, Y = np.meshgrid(x, y)
    Z = rastrigin([X, Y], lib=np)

    iter_x, iter_y = grad_iter[0, :], grad_iter[1, :]

    fig = plt.figure(figsize=(8, 8))

    ax = fig.add_subplot(1, 1, 1)
    ax.contour(X, Y, Z, 20, cmap='jet')
    ax.plot(iter_x, iter_y, color='r', marker='x')
    ax.set_title(
        f'Rastrigin func: {optimizer_name} with '
        f'{len(iter_x)} iterations, lr={lr:.6}'
    )
    plt.plot(*minimum, 'gD')
    plt.plot(iter_x[-1], iter_y[-1], 'rD')
    plt.savefig(f'rastrigin_{optimizer_name}.png')


def plot_rosenbrok(grad_iter, optimizer_name, lr):
    x = np.linspace(-2, 2, 250)
    y = np.linspace(-1, 3, 250)
    minimum = (1.0, 1.0)

    X, Y = np.meshgrid(x, y)
    Z = rosenbrock([X, Y])

    iter_x, iter_y = grad_iter[0, :], grad_iter[1, :]

    fig = plt.figure(figsize=(8, 8))

    ax = fig.add_subplot(1, 1, 1)
    ax.contour(X, Y, Z, 90, cmap='jet')
    ax.plot(iter_x, iter_y, color='r', marker='x')

    ax.set_title(
        f'Rosenbrock func: {optimizer_name} with {len(iter_x)} '
        f'iterations, lr={lr:.6}'
    )
    plt.plot(*minimum, 'gD')
    plt.plot(iter_x[-1], iter_y[-1], 'rD')
    plt.savefig(f'rosenbrock_{optimizer_name}.png')


def execute_experiments(
    optimizers, objective, func, plot_func, initial_state, seed=1
):
    for item in optimizers:
        optimizer_class, lr_low, lr_hi = item
        # search only over the learning rate; other parameters stay at defaults
        space = {
            'optimizer_class': hp.choice('optimizer_class', [optimizer_class]),
            'lr': hp.loguniform('lr', lr_low, lr_hi),
        }
        best = fmin(
            fn=objective,
            space=space,
            algo=tpe.suggest,
            max_evals=200,
            rstate=np.random.RandomState(seed),
        )
        print(best['lr'], optimizer_class)

        steps = execute_steps(
            func,
            initial_state,
            optimizer_class,
            {'lr': best['lr']},
            num_iter=500,
        )
        plot_func(steps, optimizer_class.__name__, best['lr'])


if __name__ == '__main__':
    # python examples/viz_optimizers.py

    # Each optimizer has a tweaked search space to produce better plots and
    # to help converge on a better lr faster.
    optimizers = [
        (optim.AccSGD, -8, -0.1),
        (optim.AdaBound, -8, 0.7),
        (optim.AdaMod, -8, 1.2),
        (optim.DiffGrad, -8, 0.7),
        (optim.Lamb, -8, 0.7),
        (optim.NovoGrad, -6, -2.0),
        (optim.RAdam, -8, 0.7),
        (optim.SGDW, -8, -0.9),
        (optim.Yogi, -8, 0.1),
    ]
    execute_experiments(
        optimizers, objective_rastrigin, rastrigin, plot_rastrigin, (-2.0, 3.5)
    )

    optimizers = [
        (optim.AccSGD, -8, -0.1),
        (optim.AdaBound, -8, 0.7),
        (optim.AdaMod, -4, 1.0),
        (optim.DiffGrad, -8, 0.2),
        (optim.Lamb, -8, -0.5),
        (optim.NovoGrad, -8, -1.0),
        (optim.RAdam, -8, 0.7),
        (optim.SGDW, -8, 0.7),
        (optim.Yogi, -8, 0.1),
    ]
    execute_experiments(
        optimizers,
        objective_rosenbrok,
        rosenbrock,
        plot_rosenbrok,
        (-2.0, 2.0),
    )
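
A minimal usage sketch (not part of the commit): assuming the script is run
from the examples/ directory so it is importable, a single optimizer can be
traced and plotted directly, with a hand-picked learning rate instead of a
hyperopt-found one:

.. code:: python

    import torch_optimizer as optim

    from viz_optimizers import execute_steps, plot_rastrigin, rastrigin

    # Trace RAdam on Rastrigin from the script's starting point and write
    # rastrigin_RAdam.png, skipping the full hyperopt sweep.
    steps = execute_steps(
        rastrigin, (-2.0, 3.5), optim.RAdam, {'lr': 0.01}, num_iter=500
    )
    plot_rastrigin(steps, 'RAdam', 0.01)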
