33import os
44import torch
55import argparse
6+ import pprint
67from PIL import Image
78from torch .autograd import Variable
89import torchvision .utils as vutils
2324parser .add_argument ('--decoder2' , default = 'models/vgg19_normalized_decoder2.pth.tar' , help = 'Path to the decoder2' )
2425parser .add_argument ('--decoder1' , default = 'models/vgg19_normalized_decoder1.pth.tar' , help = 'Path to the decoder1' )
2526parser .add_argument ('--cuda' , action = 'store_true' , help = 'enables cuda' )
27+ parser .add_argument ('--transform-method' , choices = ['original' , 'closed-form' ], default = 'original' ,
28+ help = ('How to whiten and color the features. "original" for the formulation of Li et al. ( https://arxiv.org/abs/1705.08086 ) '
29+ 'or "closed-form" for method of Lu et al. ( https://arxiv.org/abs/1906.00668 ' ))
2630parser .add_argument ('--batch_size' , type = int , default = 1 , help = 'batch size' )
2731parser .add_argument ('--fineSize' , type = int , default = 512 , help = 'resize image to fineSize x fineSize,leave it to 0 if not resize' )
2832parser .add_argument ('--outf' , default = 'samples/' , help = 'folder to output images' )
3236parser .add_argument ('--gpu' , type = int , default = 0 , help = "which gpu to run on. default is 0" )
3337
3438args = parser .parse_args ()
39+ pprint .pprint (args .__dict__ , indent = 2 )
3540
3641try :
3742 os .makedirs (args .outf )
4449 batch_size = 1 ,
4550 shuffle = False )
4651
47- def styleTransfer (wct , targets , contentImg , styleImg , imname , gamma , delta , outf ):
52+ def styleTransfer (wct , targets , contentImg , styleImg , imname , gamma , delta , outf , transform_method ):
4853
4954 current_result = contentImg
5055 eIorigs = [f .cpu ().squeeze (0 ) for f in wct .encoder (contentImg , targets )]
@@ -58,8 +63,8 @@ def styleTransfer(wct, targets, contentImg, styleImg, imname, gamma, delta, outf
5863 else :
5964 eIlast = wct .encoder (current_result , target ).cpu ().squeeze (0 )
6065
61- CsIlast = wct .transform (eIlast , eIs ).float ()
62- CsIorig = wct .transform (eIorig , eIs ).float ()
66+ CsIlast = wct .transform (eIlast , eIs , transform_method ).float ()
67+ CsIorig = wct .transform (eIorig , eIs , transform_method ).float ()
6368
6469 decoder_input = (gamma * (delta * CsIlast + (1 - delta ) * CsIorig ) \
6570 + (1 - gamma ) * eIorig )
@@ -91,7 +96,7 @@ def main():
9196 # WCT Style Transfer
9297 targets = [f'relu{ t } _1' for t in args .targets ]
9398 styleTransfer (wct , targets , contentImg , styleImg , imname ,
94- args .gamma , args .delta , args .outf )
99+ args .gamma , args .delta , args .outf , args . transform_method )
95100 end_time = time .time ()
96101 print (' Elapsed time is: %f' % (end_time - start_time ))
97102 avgTime += (end_time - start_time )
0 commit comments