
Commit 4c090f1

Merge pull request #50 from kashif/example
initial autoencoder example
2 parents: 936d9be + 615747c

File tree

1 file changed: +95 -0 lines


examples/autoencoder.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# FashionMNIST VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import VectorQuantize


lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"

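# `num_codes` is passed to VectorQuantize as `codebook_size` below; the
# training loop reports what fraction of these codes each batch uses.
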
class SimpleVQAutoEncoder(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                VectorQuantize(dim=32, **vq_kwargs),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuantize):
                # VectorQuantize expects (batch, seq, dim): move channels last
                # so each spatial position is one 32-dim vector, quantize, then
                # restore the (batch, channels, height, width) layout
                b, c, h, w = x.shape
                x_flat = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
                x_flat, indices, commit_loss = layer(x_flat)
                x = x_flat.reshape(b, h, w, c).permute(0, 3, 1, 2)
            else:
                x = layer(x)
        return x.clamp(-1, 1), indices, commit_loss

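# Shape note: two stride-2 poolings take the 28x28 FashionMNIST images to a
# 7x7 grid of 32-dim vectors, so the quantizer sees 49 vectors per image and
# `indices` has shape (batch, 49).
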
def train(model, train_loader, train_iterations=1000, alpha=10):
    # note: relies on the module-level optimizer `opt` defined below
    def iterate_dataset(data_loader):
        # cycle through the DataLoader indefinitely
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # create the generator once; constructing it inside the loop would
    # rebuild the DataLoader iterator on every step
    data_gen = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(data_gen)
        out, indices, cmt_loss = model(x)
        # L1 reconstruction loss plus weighted commitment loss
        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_loader = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(codebook_size=num_codes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_loader, train_iterations=train_iter)
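For context when reading `forward`: VectorQuantize consumes flat sequences of vectors rather than image feature maps, which is why the activations are reshaped around the quantizer. A minimal standalone sketch of that interface (sizes mirror the example above; this is an illustration, not part of the commit):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim=32, codebook_size=256)

x = torch.randn(1, 49, 32)  # (batch, sequence, dim), e.g. a flattened 7x7 map
quantized, indices, commit_loss = vq(x)
assert quantized.shape == (1, 49, 32)
assert indices.shape == (1, 49)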

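After training, one might eyeball the reconstructions. A short sketch, assuming it is appended to the end of the script above (the test split, grid layout, and output file name are illustrative choices, not part of this commit):

from torchvision.utils import save_image

test_loader = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=False, download=True, transform=transform
    ),
    batch_size=16,
)

model.eval()
with torch.no_grad():
    x, _ = next(iter(test_loader))
    recon, _, _ = model(x.to(device))

# inputs on the top row, reconstructions on the bottom; map [-1, 1] back to [0, 1]
grid = torch.cat([x.to(device), recon]) * 0.5 + 0.5
save_image(grid, "reconstructions.png", nrow=16)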