This repository was archived by the owner on Aug 28, 2025. It is now read-only.

Commit eeb2a30

rohitgr7 and Borda authored
Update PTL examples for v1.6 release (#146)
* update pytorch lightning examples for next release
* remove strategy
* use torchmetrics
* improvements
* use estimated stepping batches
* use auto devices
* fix examples
* Apply suggestions from code review
* limit 1 GPU

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
1 parent 1025d51 commit eeb2a30

9 files changed (+136, -110 lines changed)

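Most of the diffs below repeat one migration: the Trainer arguments the old examples used (gpus, progress_bar_refresh_rate) are replaced with accelerator/devices and a TQDMProgressBar callback. A minimal before/after sketch of that pattern, with argument values taken from the examples below for illustration:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar

# Before: device count and progress-bar refresh were set directly on the Trainer.
# trainer = Trainer(gpus=min(1, torch.cuda.device_count()), progress_bar_refresh_rate=20, max_epochs=5)

# After: accelerator/devices select the hardware, and the progress bar becomes a callback.
trainer = Trainer(
    accelerator="auto",  # pick the available accelerator automatically
    devices=1 if torch.cuda.is_available() else None,  # keep notebook runs on a single device
    callbacks=[TQDMProgressBar(refresh_rate=20)],  # replaces progress_bar_refresh_rate
    max_epochs=5,
)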

lightning_examples/augmentation_kornia/augmentation.py

Lines changed: 19 additions & 11 deletions
@@ -11,14 +11,13 @@
 from kornia import image_to_tensor, tensor_to_image
 from kornia.augmentation import ColorJitter, RandomChannelShuffle, RandomHorizontalFlip, RandomThinPlateSpline
 from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from pytorch_lightning.loggers import CSVLogger
 from torch import Tensor
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torchvision.datasets import CIFAR10

-AVAIL_GPUS = min(1, torch.cuda.device_count())
-
 # %% [markdown]
 # ## Define Data Augmentations module
 #
@@ -106,10 +105,11 @@ def __init__(self):

         self.transform = DataAugmentation()  # per batch augmentation_kornia

-        self.accuracy = torchmetrics.Accuracy()
+        self.train_accuracy = torchmetrics.Accuracy()
+        self.val_accuracy = torchmetrics.Accuracy()

     def forward(self, x):
-        return F.softmax(self.model(x))
+        return self.model(x)

     def compute_loss(self, y_hat, y):
         return F.cross_entropy(y_hat, y)
@@ -127,21 +127,28 @@ def _to_vis(data):
         plt.figure(figsize=win_size)
         plt.imshow(_to_vis(imgs_aug))

+    def on_after_batch_transfer(self, batch, dataloader_idx):
+        x, y = batch
+        if self.trainer.training:
+            x = self.transform(x)  # => we perform GPU/Batched data augmentation
+        return x, y
+
     def training_step(self, batch, batch_idx):
         x, y = batch
-        x_aug = self.transform(x)  # => we perform GPU/Batched data augmentation
-        y_hat = self(x_aug)
+        y_hat = self(x)
         loss = self.compute_loss(y_hat, y)
+        self.train_accuracy.update(y_hat, y)
         self.log("train_loss", loss, prog_bar=False)
-        self.log("train_acc", self.accuracy(y_hat, y), prog_bar=False)
+        self.log("train_acc", self.train_accuracy, prog_bar=False)
         return loss

     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = self.compute_loss(y_hat, y)
+        self.val_accuracy.update(y_hat, y)
         self.log("valid_loss", loss, prog_bar=False)
-        self.log("valid_acc", self.accuracy(y_hat, y), prog_bar=True)
+        self.log("valid_acc", self.val_accuracy, prog_bar=True)

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
@@ -158,7 +165,7 @@ def train_dataloader(self):
         return loader

     def val_dataloader(self):
-        dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
+        dataset = CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)
         loader = DataLoader(dataset, batch_size=32)
         return loader

@@ -179,8 +186,9 @@ def val_dataloader(self):
 # %%
 # Initialize a trainer
 trainer = Trainer(
-    progress_bar_refresh_rate=20,
-    gpus=AVAIL_GPUS,
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
     max_epochs=10,
     logger=CSVLogger(save_dir="logs/", name="cifar10-resnet18"),
 )
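Assembled from the hunks above, a condensed sketch of what the updated module looks like: the Kornia transform moves out of training_step into the on_after_batch_transfer hook, and the shared Accuracy metric is split into separate train/val instances. Class name and constructor arguments are simplifications, not the example's exact code:

import torchmetrics
from pytorch_lightning import LightningModule
from torch.nn import functional as F


class LitCIFAR10(LightningModule):  # condensed sketch of the example's module
    def __init__(self, model, transform):
        super().__init__()
        self.model = model
        self.transform = transform  # Kornia DataAugmentation module
        self.train_accuracy = torchmetrics.Accuracy()  # torchmetrics API as used in the example
        self.val_accuracy = torchmetrics.Accuracy()

    def on_after_batch_transfer(self, batch, dataloader_idx):
        x, y = batch
        if self.trainer.training:
            x = self.transform(x)  # GPU/batched augmentation, applied only while training
        return x, y

    def training_step(self, batch, batch_idx):
        x, y = batch  # already augmented by on_after_batch_transfer
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.train_accuracy.update(y_hat, y)
        self.log("train_acc", self.train_accuracy)  # the metric object handles epoch aggregation
        return loss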

lightning_examples/barlow-twins/barlow_twins.py

Lines changed: 5 additions & 8 deletions
@@ -283,15 +283,12 @@ def shared_step(self, batch):

     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
-
-        self.log("train_loss", loss.item(), on_step=True, on_epoch=False)
+        self.log("train_loss", loss, on_step=True, on_epoch=False)
         return loss

     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
-
         self.log("val_loss", loss, on_step=False, on_epoch=True)
-        return loss

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
@@ -329,7 +326,7 @@ def __init__(
         self.encoder_output_dim = encoder_output_dim
         self.num_classes = num_classes

-    def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

         # add linear_eval layer and optimizer
         pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)
@@ -408,12 +405,12 @@ def on_validation_batch_end(
 )

 online_finetuner = OnlineFineTuner(encoder_output_dim=encoder_out_dim, num_classes=10)
-checkpoint_callback = ModelCheckpoint(every_n_val_epochs=100, save_top_k=-1, save_last=True)
+checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)

 trainer = Trainer(
     max_epochs=max_epochs,
-    gpus=torch.cuda.device_count(),
-    precision=16 if torch.cuda.device_count() > 0 else 32,
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
     callbacks=[online_finetuner, checkpoint_callback],
 )
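Two renames drive the Barlow Twins changes: the on_pretrain_routine_start callback hook gives way to on_fit_start, and ModelCheckpoint's every_n_val_epochs argument becomes every_n_epochs. A minimal sketch of the resulting callback setup; everything outside the hook is omitted and the constructor is abbreviated from the example:

import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.callbacks import Callback, ModelCheckpoint


class OnlineFineTuner(Callback):  # condensed sketch of the example's callback
    def __init__(self, encoder_output_dim: int, num_classes: int) -> None:
        super().__init__()
        self.encoder_output_dim = encoder_output_dim
        self.num_classes = num_classes

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # attach the linear-eval head once fitting starts (was on_pretrain_routine_start)
        pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)


# every_n_val_epochs -> every_n_epochs
checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)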

lightning_examples/basic-gan/gan.py

Lines changed: 13 additions & 11 deletions
@@ -1,6 +1,5 @@
 # %%
 import os
-from collections import OrderedDict

 import numpy as np
 import torch
@@ -9,12 +8,12 @@
 import torchvision
 import torchvision.transforms as transforms
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from torch.utils.data import DataLoader, random_split
 from torchvision.datasets import MNIST

 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
-AVAIL_GPUS = min(1, torch.cuda.device_count())
-BATCH_SIZE = 256 if AVAIL_GPUS else 64
+BATCH_SIZE = 256 if torch.cuda.is_available() else 64
 NUM_WORKERS = int(os.cpu_count() / 2)

 # %% [markdown]
@@ -205,9 +204,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

             # adversarial loss is binary cross-entropy
             g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
-            tqdm_dict = {"g_loss": g_loss}
-            output = OrderedDict({"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
-            return output
+            self.log("g_loss", g_loss, prog_bar=True)
+            return g_loss

         # train discriminator
         if optimizer_idx == 1:
@@ -227,9 +225,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

             # discriminator loss is the average of these
             d_loss = (real_loss + fake_loss) / 2
-            tqdm_dict = {"d_loss": d_loss}
-            output = OrderedDict({"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
-            return output
+            self.log("d_loss", d_loss, prog_bar=True)
+            return d_loss

     def configure_optimizers(self):
         lr = self.hparams.lr
@@ -240,7 +237,7 @@ def configure_optimizers(self):
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []

-    def on_epoch_end(self):
+    def on_validation_epoch_end(self):
         z = self.validation_z.type_as(self.generator.model[0].weight)

         # log sampled images
@@ -252,7 +249,12 @@ def on_epoch_end(self):
 # %%
 dm = MNISTDataModule()
 model = GAN(*dm.size())
-trainer = Trainer(gpus=AVAIL_GPUS, max_epochs=5, progress_bar_refresh_rate=20)
+trainer = Trainer(
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
+    max_epochs=5,
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
+)
 trainer.fit(model, dm)

 # %%
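With the OrderedDict return gone, each optimizer branch of training_step logs through self.log and returns the bare loss tensor. A condensed sketch of the generator branch after the change; this is a method of the example's GAN LightningModule, the discriminator branch is omitted, and the latent_dim hyperparameter name is assumed rather than quoted from the diff:

def training_step(self, batch, batch_idx, optimizer_idx):
    imgs, _ = batch

    if optimizer_idx == 0:  # train the generator
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim).type_as(imgs)  # assumed noise sampling
        valid = torch.ones(imgs.size(0), 1).type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        self.log("g_loss", g_loss, prog_bar=True)  # replaces the progress_bar/log dicts
        return g_loss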

lightning_examples/cifar10-baseline/baseline.py

Lines changed: 9 additions & 9 deletions
@@ -13,6 +13,7 @@
 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
 from pytorch_lightning import LightningModule, Trainer, seed_everything
 from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from pytorch_lightning.loggers import TensorBoardLogger
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.optim.swa_utils import AveragedModel, update_bn
@@ -21,8 +22,7 @@
 seed_everything(7)

 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
-AVAIL_GPUS = min(1, torch.cuda.device_count())
-BATCH_SIZE = 256 if AVAIL_GPUS else 64
+BATCH_SIZE = 256 if torch.cuda.is_available() else 64
 NUM_WORKERS = int(os.cpu_count() / 2)

 # %% [markdown]
@@ -137,14 +137,13 @@ def configure_optimizers(self):

 # %%
 model = LitResnet(lr=0.05)
-model.datamodule = cifar10_dm

 trainer = Trainer(
-    progress_bar_refresh_rate=10,
     max_epochs=30,
-    gpus=AVAIL_GPUS,
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
     logger=TensorBoardLogger("lightning_logs/", name="resnet"),
-    callbacks=[LearningRateMonitor(logging_interval="step")],
+    callbacks=[LearningRateMonitor(logging_interval="step"), TQDMProgressBar(refresh_rate=10)],
 )

 trainer.fit(model, cifar10_dm)
@@ -189,18 +188,19 @@ def configure_optimizers(self):
         return optimizer

     def on_train_end(self):
-        update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)
+        update_bn(self.trainer.datamodule.train_dataloader(), self.swa_model, device=self.device)


 # %%
 swa_model = SWAResnet(model.model, lr=0.01)
 swa_model.datamodule = cifar10_dm

 swa_trainer = Trainer(
-    progress_bar_refresh_rate=20,
     max_epochs=20,
-    gpus=AVAIL_GPUS,
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
     logger=TensorBoardLogger("lightning_logs/", name="swa_resnet"),
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
 )

 swa_trainer.fit(swa_model, cifar10_dm)
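The SWA part no longer reads a datamodule that was manually attached to the model: because cifar10_dm is passed to trainer.fit, hooks can reach it through self.trainer.datamodule. A minimal sketch of the pattern; this is just the hook, with the rest of the SWAResnet module omitted:

def on_train_end(self):
    # the datamodule handed to trainer.fit(...) is exposed as self.trainer.datamodule,
    # so a manual `model.datamodule = ...` assignment is no longer needed for this hook
    update_bn(self.trainer.datamodule.train_dataloader(), self.swa_model, device=self.device)

swa_trainer.fit(swa_model, cifar10_dm)  # cifar10_dm becomes swa_trainer.datamodule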

lightning_examples/datamodules/datamodules.py

Lines changed: 12 additions & 11 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn.functional as F
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from torch import nn
 from torch.utils.data import DataLoader, random_split
 from torchmetrics.functional import accuracy
@@ -18,8 +19,7 @@
 from torchvision.datasets import CIFAR10, MNIST

 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
-AVAIL_GPUS = min(1, torch.cuda.device_count())
-BATCH_SIZE = 256 if AVAIL_GPUS else 64
+BATCH_SIZE = 256 if torch.cuda.is_available() else 64

 # %% [markdown]
 # ### Defining the LitMNISTModel
@@ -84,7 +84,6 @@ def validation_step(self, batch, batch_idx):
         acc = accuracy(preds, y)
         self.log("val_loss", loss, prog_bar=True)
         self.log("val_acc", acc, prog_bar=True)
-        return loss

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
@@ -127,8 +126,9 @@ def test_dataloader(self):
 model = LitMNIST()
 trainer = Trainer(
     max_epochs=2,
-    gpus=AVAIL_GPUS,
-    progress_bar_refresh_rate=20,
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
 )
 trainer.fit(model)

@@ -252,15 +252,13 @@ def training_step(self, batch, batch_idx):
         return loss

     def validation_step(self, batch, batch_idx):
-
         x, y = batch
         logits = self(x)
         loss = F.nll_loss(logits, y)
         preds = torch.argmax(logits, dim=1)
         acc = accuracy(preds, y)
         self.log("val_loss", loss, prog_bar=True)
         self.log("val_acc", acc, prog_bar=True)
-        return loss

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
@@ -280,8 +278,9 @@ def configure_optimizers(self):
 # Init trainer
 trainer = Trainer(
     max_epochs=3,
-    progress_bar_refresh_rate=20,
-    gpus=AVAIL_GPUS,
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
 )
 # Pass the datamodule as arg to trainer.fit to override model hooks :)
 trainer.fit(model, dm)
@@ -343,9 +342,11 @@ def test_dataloader(self):
 # %%
 dm = CIFAR10DataModule()
 model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)
+tqdm_progress_bar = TQDMProgressBar(refresh_rate=20)
 trainer = Trainer(
     max_epochs=5,
-    progress_bar_refresh_rate=20,
-    gpus=AVAIL_GPUS,
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
+    callbacks=[tqdm_progress_bar],
 )
 trainer.fit(model, dm)
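One small change in the datamodules example is easy to miss: validation_step no longer returns the loss. Returning from validation_step is only needed when a validation_epoch_end hook consumes the outputs, so logging alone is sufficient here. A condensed sketch of the resulting step, with names following the example and the surrounding module omitted:

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    self.log("val_loss", loss, prog_bar=True)
    self.log("val_acc", acc, prog_bar=True)
    # no return: nothing downstream reads the validation outputs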
