
Commit eaca2ee

flesh out the maskbit trainer
1 parent 74ec85e commit eaca2ee

File tree

3 files changed (+210, -6 lines)

maskbit_pytorch/maskbit.py
maskbit_pytorch/trainer.py
pyproject.toml

maskbit_pytorch/maskbit.py

Lines changed: 3 additions & 0 deletions
@@ -546,6 +546,9 @@ def __init__(

         self._c = vae.channels

+    def parameters(self):
+        return self.demasking_transformer.parameters()
+
     @property
     def device(self):
         return next(self.parameters()).device
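
Side note on the parameters() override added above: presumably it narrows what the optimizer sees, so that Adam(maskbit.parameters(), lr = lr) in the new trainer only updates the demasking transformer and leaves the wrapped BQVAE out of the optimizer. A minimal sketch of the same pattern in plain PyTorch; Wrapper, backbone and head are placeholder names, not identifiers from this repo:

    import torch
    from torch import nn
    from torch.optim import Adam

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)  # stand-in for the frozen BQVAE
            self.head = nn.Linear(8, 8)      # stand-in for the demasking transformer

        def parameters(self):
            # expose only the trainable sub-module, mirroring the override in the diff above
            return self.head.parameters()

    model = Wrapper()
    optim = Adam(model.parameters(), lr = 3e-4)        # the optimizer only ever sees the head's weights
    print(sum(p.numel() for p in model.parameters()))  # 72, not 144 - backbone params are excluded

Note that state_dict() and .to(device) traverse sub-modules through their own machinery, so checkpoints and device moves still cover the wrapped VAE even with parameters() overridden.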

maskbit_pytorch/trainer.py

Lines changed: 206 additions & 5 deletions
@@ -15,7 +15,7 @@
 from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image

-from maskbit_pytorch.maskbit import BQVAE
+from maskbit_pytorch.maskbit import BQVAE, MaskBit

 from einops import rearrange

@@ -357,9 +357,9 @@ def train_step(self):

         self.discr_optim.step()

-        # log
+        # log

-        self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}")
+        self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}")

         # update exponential moving averaged generator

@@ -424,6 +424,207 @@ def forward(self):
 # maskbit trainer

 class MaskBitTrainer(Module):
-    def __init__(self):
+    def __init__(
+        self,
+        maskbit: MaskBit,
+        folder,
+        num_train_steps,
+        batch_size,
+        image_size,
+        lr = 3e-4,
+        grad_accum_every = 1,
+        max_grad_norm = None,
+        save_results_every = 100,
+        save_model_every = 1000,
+        results_folder = './results',
+        valid_frac = 0.05,
+        random_split_seed = 42,
+        accelerate_kwargs: dict = dict()
+    ):
         super().__init__()
-        raise NotImplementedError
+
+        # instantiate accelerator
+
+        kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])
+
+        ddp_kwargs = find_and_pop(
+            kwargs_handlers,
+            lambda x: isinstance(x, DistributedDataParallelKwargs),
+            partial(DistributedDataParallelKwargs, find_unused_parameters = True)
+        )
+
+        ddp_kwargs.find_unused_parameters = True
+        kwargs_handlers.append(ddp_kwargs)
+        accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)
+
+        self.accelerator = Accelerator(**accelerate_kwargs)
+
+        # training params
+
+        self.register_buffer('steps', tensor(0))
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.grad_accum_every = grad_accum_every
+
+        # model
+
+        self.maskbit = maskbit
+
+        # optimizers
+
+        self.optim = Adam(maskbit.parameters(), lr = lr)
+
+        self.max_grad_norm = max_grad_norm
+
+        # create dataset
+
+        self.ds = ImageDataset(folder, image_size)
+
+        # split for validation
+
+        if valid_frac > 0:
+            train_size = int((1 - valid_frac) * len(self.ds))
+            valid_size = len(self.ds) - train_size
+            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
+            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
+        else:
+            self.valid_ds = self.ds
+            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
+
+        # dataloader
+
+        self.dl = DataLoader(
+            self.ds,
+            batch_size = batch_size,
+            shuffle = True
+        )
+
+        self.valid_dl = DataLoader(
+            self.valid_ds,
+            batch_size = batch_size,
+            shuffle = True
+        )
+
+        # prepare with accelerator
+
+        (
+            self.maskbit,
+            self.optim,
+            self.dl,
+            self.valid_dl
+        ) = self.accelerator.prepare(
+            self.maskbit,
+            self.optim,
+            self.dl,
+            self.valid_dl
+        )
+
+        self.dl_iter = cycle(self.dl)
+        self.valid_dl_iter = cycle(self.valid_dl)
+
+        self.save_model_every = save_model_every
+        self.save_results_every = save_results_every
+
+        self.results_folder = Path(results_folder)
+
+        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
+            rmtree(str(self.results_folder))
+
+        self.results_folder.mkdir(parents = True, exist_ok = True)
+
+    def save(self, path):
+        if not self.accelerator.is_local_main_process:
+            return
+
+        pkg = dict(
+            model = self.accelerator.get_state_dict(self.maskbit),
+            optim = self.optim.state_dict(),
+        )
+
+        torch.save(pkg, path)
+
+    def load(self, path):
+        path = Path(path)
+        assert path.exists()
+        pkg = torch.load(path)
+
+        maskbit = self.accelerator.unwrap_model(self.maskbit)
+        maskbit.load_state_dict(pkg['model'])
+
+        self.optim.load_state_dict(pkg['optim'])
+
+    def print(self, msg):
+        self.accelerator.print(msg)
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def is_distributed(self):
+        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
+
+    @property
+    def is_main(self):
+        return self.accelerator.is_main_process
+
+    @property
+    def is_local_main(self):
+        return self.accelerator.is_local_main_process
+
+    def train_step(self):
+        acc = self.accelerator
+        device = self.device
+
+        steps = int(self.steps.item())
+
+        self.maskbit.train()
+
+        # logs
+
+        logs = dict()
+
+        # update maskbit (demasking transformer)
+
+        for _ in range(self.grad_accum_every):
+            img = next(self.dl_iter)
+            img = img.to(device)
+
+            with acc.autocast():
+                loss = self.maskbit(img)
+
+            acc.backward(loss / self.grad_accum_every)
+
+            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
+
+        if exists(self.max_grad_norm):
+            acc.clip_grad_norm_(self.maskbit.parameters(), self.max_grad_norm)
+
+        self.optim.step()
+        self.optim.zero_grad()
+
+        # log
+
+        self.print(f"{steps}: maskbit loss: {logs['loss']:.3f}")
+
+        # save model every so often
+
+        acc.wait_for_everyone()
+
+        if self.is_main and not (steps % self.save_model_every):
+            state_dict = acc.unwrap_model(self.maskbit).state_dict()
+            model_path = str(self.results_folder / f'maskbit.{steps}.pt')
+            acc.save(state_dict, model_path)
+
+            self.print(f'{steps}: saving model to {str(self.results_folder)}')
+
+        self.steps += 1
+        return logs
+
+    def forward(self):
+
+        while self.steps < self.num_train_steps:
+            logs = self.train_step()
+
+        self.print('training complete')
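
With that in place, the new MaskBitTrainer follows the same shape as the BQVAE trainer earlier in this file: it wraps the model, optimizer and dataloaders with an accelerate Accelerator, accumulates gradients over grad_accum_every micro-batches, optionally clips gradients when max_grad_norm is set, and saves a checkpoint to results_folder every save_model_every steps. A rough usage sketch; the MaskBitTrainer arguments come from the signature above, but the BQVAE and MaskBit constructor arguments and the checkpoint path are illustrative placeholders, not taken from this commit:

    import torch
    from maskbit_pytorch.maskbit import BQVAE, MaskBit
    from maskbit_pytorch.trainer import MaskBitTrainer

    # BQVAE / MaskBit constructor arguments below are placeholders -
    # check maskbit.py for the real signatures

    vae = BQVAE(image_size = 256)

    vae.load_state_dict(torch.load('./results/vae.100000.pt'))  # a checkpoint from the BQVAE trainer

    maskbit = MaskBit(vae, dim = 512, depth = 8)

    trainer = MaskBitTrainer(
        maskbit,                      # only the demasking transformer gets optimized, per parameters() above
        folder = '/path/to/images',
        num_train_steps = 100_000,
        batch_size = 4,
        image_size = 256,
        grad_accum_every = 4,
        lr = 3e-4
    )

    trainer()  # MaskBitTrainer is a Module, so calling it runs forward(), i.e. the training loop

Checkpoints land in results_folder as maskbit.{steps}.pt, and trainer.save(path) / trainer.load(path) bundle the model and optimizer state into a single file.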

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "maskbit-pytorch"
-version = "0.0.1"
+version = "0.0.2"
 description = "MaskBit"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
