15 | 15 | from torchvision.datasets import ImageFolder |
16 | 16 | from torchvision.utils import make_grid, save_image |
17 | 17 |
18 | | -from maskbit_pytorch.maskbit import BQVAE |
| 18 | +from maskbit_pytorch.maskbit import BQVAE, MaskBit |
19 | 19 |
20 | 20 | from einops import rearrange |
21 | 21 |
@@ -357,9 +357,9 @@ def train_step(self): |
357 | 357 |
358 | 358 | self.discr_optim.step() |
359 | 359 |
360 | | - # log |
| 360 | + # log |
361 | 361 |
362 | | - self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}") |
| 362 | + self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}") |
363 | 363 |
364 | 364 | # update exponential moving averaged generator |
365 | 365 |
@@ -424,6 +424,207 @@ def forward(self): |
424 | 424 | # maskbit trainer |
425 | 425 |
426 | 426 | class MaskBitTrainer(Module): |
427 | | - def __init__(self): |
| 427 | + def __init__( |
| 428 | + self, |
| 429 | + maskbit: MaskBit, |
| 430 | + folder, |
| 431 | + num_train_steps, |
| 432 | + batch_size, |
| 433 | + image_size, |
| 434 | + lr = 3e-4, |
| 435 | + grad_accum_every = 1, |
| 436 | + max_grad_norm = None, |
| 437 | + save_results_every = 100, |
| 438 | + save_model_every = 1000, |
| 439 | + results_folder = './results', |
| 440 | + valid_frac = 0.05, |
| 441 | + random_split_seed = 42, |
| 442 | + accelerate_kwargs: dict = dict() |
| 443 | + ): |
428 | 444 | super().__init__() |
429 | | - raise NotImplementedError |
| 445 | + |
| 446 | + # instantiate accelerator |
| 447 | + |
| 448 | + kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', []) |
| 449 | + |
| 450 | + ddp_kwargs = find_and_pop( |
| 451 | + kwargs_handlers, |
| 452 | + lambda x: isinstance(x, DistributedDataParallelKwargs), |
| 453 | + partial(DistributedDataParallelKwargs, find_unused_parameters = True) |
| 454 | + ) |
| 455 | + |
| 456 | + ddp_kwargs.find_unused_parameters = True |
| 457 | + kwargs_handlers.append(ddp_kwargs) |
| 458 | + accelerate_kwargs.update(kwargs_handlers = kwargs_handlers) |
| 459 | + |
| 460 | + self.accelerator = Accelerator(**accelerate_kwargs) |
| 461 | + |
| 462 | + # training params |
| 463 | + |
| 464 | + self.register_buffer('steps', tensor(0)) |
| 465 | + |
| 466 | + self.num_train_steps = num_train_steps |
| 467 | + self.batch_size = batch_size |
| 468 | + self.grad_accum_every = grad_accum_every |
| 469 | + |
| 470 | + # model |
| 471 | + |
| 472 | + self.maskbit = maskbit |
| 473 | + |
| 474 | + # optimizers |
| 475 | + |
| 476 | + self.optim = Adam(maskbit.parameters(), lr = lr) |
| 477 | + |
| 478 | + self.max_grad_norm = max_grad_norm |
| 479 | + |
| 480 | + # create dataset |
| 481 | + |
| 482 | + self.ds = ImageDataset(folder, image_size) |
| 483 | + |
| 484 | + # split for validation |
| 485 | + |
| 486 | + if valid_frac > 0: |
| 487 | + train_size = int((1 - valid_frac) * len(self.ds)) |
| 488 | + valid_size = len(self.ds) - train_size |
| 489 | + self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) |
| 490 | +            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
| 491 | + else: |
| 492 | + self.valid_ds = self.ds |
| 493 | + self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') |
| 494 | + |
| 495 | + # dataloader |
| 496 | + |
| 497 | + self.dl = DataLoader( |
| 498 | + self.ds, |
| 499 | + batch_size = batch_size, |
| 500 | + shuffle = True |
| 501 | + ) |
| 502 | + |
| 503 | + self.valid_dl = DataLoader( |
| 504 | + self.valid_ds, |
| 505 | + batch_size = batch_size, |
| 506 | + shuffle = True |
| 507 | + ) |
| 508 | + |
| 509 | + # prepare with accelerator |
| 510 | + |
| 511 | + ( |
| 512 | + self.maskbit, |
| 513 | + self.optim, |
| 514 | + self.dl, |
| 515 | + self.valid_dl |
| 516 | + ) = self.accelerator.prepare( |
| 517 | + self.maskbit, |
| 518 | + self.optim, |
| 519 | + self.dl, |
| 520 | + self.valid_dl |
| 521 | + ) |
| 522 | + |
| 523 | + self.dl_iter = cycle(self.dl) |
| 524 | + self.valid_dl_iter = cycle(self.valid_dl) |
| 525 | + |
| 526 | + self.save_model_every = save_model_every |
| 527 | + self.save_results_every = save_results_every |
| 528 | + |
| 529 | + self.results_folder = Path(results_folder) |
| 530 | + |
| 531 | + if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): |
| 532 | + rmtree(str(self.results_folder)) |
| 533 | + |
| 534 | + self.results_folder.mkdir(parents = True, exist_ok = True) |
| 535 | + |
| 536 | + def save(self, path): |
| 537 | + if not self.accelerator.is_local_main_process: |
| 538 | + return |
| 539 | + |
| 540 | + pkg = dict( |
| 541 | + model = self.accelerator.get_state_dict(self.maskbit), |
| 542 | + optim = self.optim.state_dict(), |
| 543 | + ) |
| 544 | + |
| 545 | + torch.save(pkg, path) |
| 546 | + |
| 547 | + def load(self, path): |
| 548 | + path = Path(path) |
| 549 | + assert path.exists() |
| 550 | + pkg = torch.load(path) |
| 551 | + |
| 552 | + maskbit = self.accelerator.unwrap_model(self.maskbit) |
| 553 | + maskbit.load_state_dict(pkg['model']) |
| 554 | + |
| 555 | + self.optim.load_state_dict(pkg['optim']) |
| 556 | + |
| 557 | + def print(self, msg): |
| 558 | + self.accelerator.print(msg) |
| 559 | + |
| 560 | + @property |
| 561 | + def device(self): |
| 562 | + return self.accelerator.device |
| 563 | + |
| 564 | + @property |
| 565 | + def is_distributed(self): |
| 566 | + return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) |
| 567 | + |
| 568 | + @property |
| 569 | + def is_main(self): |
| 570 | + return self.accelerator.is_main_process |
| 571 | + |
| 572 | + @property |
| 573 | + def is_local_main(self): |
| 574 | + return self.accelerator.is_local_main_process |
| 575 | + |
| 576 | + def train_step(self): |
| 577 | + acc = self.accelerator |
| 578 | + device = self.device |
| 579 | + |
| 580 | + steps = int(self.steps.item()) |
| 581 | + |
| 582 | + self.maskbit.train() |
| 583 | + |
| 584 | + # logs |
| 585 | + |
| 586 | + logs = dict() |
| 587 | + |
| 588 | + # update vae (generator) |
| 589 | + |
| 590 | + for _ in range(self.grad_accum_every): |
| 591 | + img = next(self.dl_iter) |
| 592 | + img = img.to(device) |
| 593 | + |
| 594 | + with acc.autocast(): |
| 595 | + loss = self.maskbit(img) |
| 596 | + |
| 597 | + acc.backward(loss / self.grad_accum_every) |
| 598 | + |
| 599 | + accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) |
| 600 | + |
| 601 | + if exists(self.max_grad_norm): |
| 602 | + acc.clip_grad_norm_(self.maskbit.parameters(), self.max_grad_norm) |
| 603 | + |
| 604 | + self.optim.step() |
| 605 | + self.optim.zero_grad() |
| 606 | + |
| 607 | + # log |
| 608 | + |
| 609 | + self.print(f"{steps}: maskbit loss: {logs['loss']:.3f}") |
| 610 | + |
| 611 | + # save model every so often |
| 612 | + |
| 613 | + acc.wait_for_everyone() |
| 614 | + |
| 615 | + if self.is_main and not (steps % self.save_model_every): |
| 616 | + state_dict = acc.unwrap_model(self.maskbit).state_dict() |
| 617 | + model_path = str(self.results_folder / f'maskbit.{steps}.pt') |
| 618 | + acc.save(state_dict, model_path) |
| 619 | + |
| 620 | + self.print(f'{steps}: saving model to {str(self.results_folder)}') |
| 621 | + |
| 622 | + self.steps += 1 |
| 623 | + return logs |
| 624 | + |
| 625 | + def forward(self): |
| 626 | + |
| 627 | + while self.steps < self.num_train_steps: |
| 628 | + logs = self.train_step() |
| 629 | + |
| 630 | + self.print('training complete') |
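
For orientation, below is a minimal usage sketch of the `MaskBitTrainer` added in this commit. It is a hypothetical example, not part of the diff: the `BQVAE` and `MaskBit` constructor arguments, the `maskbit_pytorch.trainers` import path, and the folder paths are assumptions; only the trainer keyword arguments come from the code above.

```python
# hypothetical usage sketch - the BQVAE / MaskBit constructor kwargs and the
# trainers import path are assumed, not taken from this commit

from maskbit_pytorch.maskbit import BQVAE, MaskBit
from maskbit_pytorch.trainers import MaskBitTrainer   # assumed module path

# stage 1 model: binary-quantized VAE (kwargs below are placeholders)
vae = BQVAE(
    image_size = 256,
    channels = 3
)

# stage 2 model: MaskBit generator wrapping the VAE (kwargs below are placeholders)
maskbit = MaskBit(
    vae = vae,
    dim = 512,
    depth = 8
)

trainer = MaskBitTrainer(
    maskbit,
    folder = './path/to/images',   # folder of training images
    num_train_steps = 100_000,
    batch_size = 4,
    image_size = 256,
    lr = 3e-4,
    grad_accum_every = 4,
    max_grad_norm = 0.5,
    results_folder = './results'
)

trainer()   # MaskBitTrainer is an nn.Module; calling it runs the training loop in forward()
```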