|
| 1 | +# pro_gan_pytorch |
| 2 | +Package contains implementation of ProGAN. |
| 3 | +Paper titled "Progressive growing of GANs for improved |
| 4 | +Quality, Stability, and Variation". <br> |
| 5 | +link -> https://arxiv.org/abs/1710.10196 |
| 6 | + |
| 7 | +# Steps to use: |
| 8 | +1.) Install your appropriate version of PyTorch. |
| 9 | +The torch dependency in this package uses the most basic |
| 10 | +"cpu" version. follow instructions on |
| 11 | +<a href="http://pytorch.org/"> http://pytorch.org </a> to |
| 12 | +install the "gpu" version of PyTorch.<br> |
| 13 | + |
| 14 | +2.) Install this package using pip: |
| 15 | + |
| 16 | + $ workon [your virtual environment] |
| 17 | + $ pip install pro-gan-pth |
| 18 | + |
| 19 | +3.) In your code: |
| 20 | + |
| 21 | + import pytorch_pro_gan.PRO_GAN as pg |
| 22 | + |
| 23 | + Use the modules `pg.Generator`, `pg.Discriminator` and |
| 24 | + `pg.ProGAN`. |
| 25 | + |
| 26 | + Help on class ProGAN in module pro_gan_pytorch.PRO_GAN: |
| 27 | + |
| 28 | + class ProGAN(builtins.object) |
| 29 | + | Wrapper around the Generator and the Discriminator |
| 30 | + | |
| 31 | + | Methods defined here: |
| 32 | + | |
| 33 | + | __init__(self, depth=7, latent_size=64, learning_rate=0.001, beta_1=0, beta_2=0.99, eps=1e-08, drift=0.001, n_critic=1, device=device(type='cpu')) |
| 34 | + | constructor for the class |
| 35 | + | :param depth: depth of the GAN (will be used for each generator and discriminator) |
| 36 | + | :param latent_size: latent size of the manifold used by the GAN |
| 37 | + | :param learning_rate: learning rate for Adam |
| 38 | + | :param beta_1: beta_1 for Adam |
| 39 | + | :param beta_2: beta_2 for Adam |
| 40 | + | :param eps: epsilon for Adam |
| 41 | + | :param n_critic: number of times to update discriminator |
| 42 | + | :param device: device to run the GAN on (GPU / CPU) |
| 43 | + | |
| 44 | + | optimize_discriminator(self, noise, real_batch, depth, alpha) |
| 45 | + | performs one step of weight update on discriminator using the batch of data |
| 46 | + | :param noise: input noise of sample generation |
| 47 | + | :param real_batch: real samples batch |
| 48 | + | :param depth: current depth of optimization |
| 49 | + | :param alpha: current alpha for fade-in |
| 50 | + | :return: current loss (Wasserstein loss) |
| 51 | + | |
| 52 | + | optimize_generator(self, noise, depth, alpha) |
| 53 | + | performs one step of weight update on generator for the given batch_size |
| 54 | + | :param noise: input random noise required for generating samples |
| 55 | + | :param depth: depth of the network at which optimization is done |
| 56 | + | :param alpha: value of alpha for fade-in effect |
| 57 | + | :return: current loss (Wasserstein estimate) |
| 58 | + | |
| 59 | + | ---------------------------------------------------------------------- |
| 60 | + | Data descriptors defined here: |
| 61 | + | |
| 62 | + | __dict__ |
| 63 | + | dictionary for instance variables (if defined) |
| 64 | + | |
| 65 | + | __weakref__ |
| 66 | + | list of weak references to the object (if defined) |
0 commit comments