diff --git a/ddpm_conditional.py b/ddpm_conditional.py index d2e0cecff..0fc104865 100644 --- a/ddpm_conditional.py +++ b/ddpm_conditional.py @@ -98,7 +98,7 @@ def train(args): logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i) if epoch % 10 == 0: - labels = torch.arange(10).long().to(device) + labels = torch.arange(args.num_classes).long().to(device) sampled_images = diffusion.sample(model, n=len(labels), labels=labels) ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels) plot_images(sampled_images)