Skip to content

Commit 140d2e3

Browse files
migrate to pytorch 1.0.
use torch.no_grad() instead of volatile variables. update example images to also use relu5_1 stylization. Point models download to new folder. add .gitignore
1 parent 452846f commit 140d2e3

File tree

10 files changed

+94
-881
lines changed

10 files changed

+94
-881
lines changed

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
models
2+
models.zip
3+
models_pytorch.zip
4+
5+
# python artifacts
6+
__pycache__
7+
*.pyc
8+
9+
# Temporary Files
10+
*~
11+
*.swp
12+
*.swo

Readme.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Official Torch implementation can be found [here](https://github.com/Yijunmaveri
77
## Prerequisites
88
- [Pytorch](http://pytorch.org/)
99
- [torchvision](https://github.com/pytorch/vision)
10-
- Pretrained encoder and decoder [models](https://drive.google.com/file/d/1M5KBPfqrIUZqrBZf78CIxLrMUT4lD4t9/view?usp=sharing) for image reconstruction only (download and uncompress them under models/)
10+
- Pretrained encoder and decoder [models](http://pascal.inrialpes.fr/data2/archetypal_style/models_pytorch.zip) for image reconstruction only (download and uncompress them under models/)
1111
- CUDA + CuDNN
1212

1313
## Prepare images

WCT.py

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,18 @@
88
from Loader import Dataset
99
from util import *
1010
import scipy.misc
11-
from torch.utils.serialization import load_lua
1211
import time
1312

1413
parser = argparse.ArgumentParser(description='WCT Pytorch')
1514
parser.add_argument('--contentPath',default='images/content',help='path to train')
1615
parser.add_argument('--stylePath',default='images/style',help='path to train')
1716
parser.add_argument('--workers', default=2, type=int, metavar='N',help='number of data loading workers (default: 4)')
18-
parser.add_argument('--vgg1', default='models/vgg_normalised_conv1_1.t7', help='Path to the VGG conv1_1')
19-
parser.add_argument('--vgg2', default='models/vgg_normalised_conv2_1.t7', help='Path to the VGG conv2_1')
20-
parser.add_argument('--vgg3', default='models/vgg_normalised_conv3_1.t7', help='Path to the VGG conv3_1')
21-
parser.add_argument('--vgg4', default='models/vgg_normalised_conv4_1.t7', help='Path to the VGG conv4_1')
22-
parser.add_argument('--vgg5', default='models/vgg_normalised_conv5_1.t7', help='Path to the VGG conv5_1')
23-
parser.add_argument('--decoder5', default='models/feature_invertor_conv5_1.t7', help='Path to the decoder5')
24-
parser.add_argument('--decoder4', default='models/feature_invertor_conv4_1.t7', help='Path to the decoder4')
25-
parser.add_argument('--decoder3', default='models/feature_invertor_conv3_1.t7', help='Path to the decoder3')
26-
parser.add_argument('--decoder2', default='models/feature_invertor_conv2_1.t7', help='Path to the decoder2')
27-
parser.add_argument('--decoder1', default='models/feature_invertor_conv1_1.t7', help='Path to the decoder1')
17+
parser.add_argument('--encoder', default='models/vgg19_normalized.pth.tar', help='Path to the VGG conv1_1')
18+
parser.add_argument('--decoder5', default='models/vgg19_normalized_decoder5.pth.tar', help='Path to the decoder5')
19+
parser.add_argument('--decoder4', default='models/vgg19_normalized_decoder4.pth.tar', help='Path to the decoder4')
20+
parser.add_argument('--decoder3', default='models/vgg19_normalized_decoder3.pth.tar', help='Path to the decoder3')
21+
parser.add_argument('--decoder2', default='models/vgg19_normalized_decoder2.pth.tar', help='Path to the decoder2')
22+
parser.add_argument('--decoder1', default='models/vgg19_normalized_decoder1.pth.tar', help='Path to the decoder1')
2823
parser.add_argument('--cuda', action='store_true', help='enables cuda')
2924
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
3025
parser.add_argument('--fineSize', type=int, default=512, help='resize image to fineSize x fineSize,leave it to 0 if not resize')
@@ -45,70 +40,75 @@
4540
batch_size=1,
4641
shuffle=False)
4742

48-
wct = WCT(args)
49-
def styleTransfer(contentImg,styleImg,imname,csF):
43+
def styleTransfer(wct, contentImg, styleImg, imname):
5044

51-
sF5 = wct.e5(styleImg)
52-
cF5 = wct.e5(contentImg)
53-
sF5 = sF5.data.cpu().squeeze(0)
54-
cF5 = cF5.data.cpu().squeeze(0)
55-
csF5 = wct.transform(cF5,sF5,csF,args.alpha)
45+
sF5 = wct.encoder(styleImg, 'relu5_1')
46+
cF5 = wct.encoder(contentImg, 'relu5_1')
47+
sF5 = sF5.cpu().squeeze(0)
48+
cF5 = cF5.cpu().squeeze(0)
49+
csF5 = wct.transform(cF5,sF5,args.alpha)
50+
csF5 = csF5.to(next(wct.parameters()).device)
5651
Im5 = wct.d5(csF5)
5752

58-
sF4 = wct.e4(styleImg)
59-
cF4 = wct.e4(Im5)
60-
sF4 = sF4.data.cpu().squeeze(0)
61-
cF4 = cF4.data.cpu().squeeze(0)
62-
csF4 = wct.transform(cF4,sF4,csF,args.alpha)
53+
sF4 = wct.encoder(styleImg, 'relu4_1')
54+
cF4 = wct.encoder(Im5, 'relu4_1')
55+
sF4 = sF4.cpu().squeeze(0)
56+
cF4 = cF4.cpu().squeeze(0)
57+
csF4 = wct.transform(cF4,sF4,args.alpha)
58+
csF4 = csF4.to(next(wct.parameters()).device)
6359
Im4 = wct.d4(csF4)
6460

65-
sF3 = wct.e3(styleImg)
66-
cF3 = wct.e3(Im4)
67-
sF3 = sF3.data.cpu().squeeze(0)
68-
cF3 = cF3.data.cpu().squeeze(0)
69-
csF3 = wct.transform(cF3,sF3,csF,args.alpha)
61+
sF3 = wct.encoder(styleImg, 'relu3_1')
62+
cF3 = wct.encoder(Im4, 'relu3_1')
63+
sF3 = sF3.cpu().squeeze(0)
64+
cF3 = cF3.cpu().squeeze(0)
65+
csF3 = wct.transform(cF3,sF3,args.alpha)
66+
csF3 = csF3.to(next(wct.parameters()).device)
7067
Im3 = wct.d3(csF3)
7168

72-
sF2 = wct.e2(styleImg)
73-
cF2 = wct.e2(Im3)
74-
sF2 = sF2.data.cpu().squeeze(0)
75-
cF2 = cF2.data.cpu().squeeze(0)
76-
csF2 = wct.transform(cF2,sF2,csF,args.alpha)
69+
sF2 = wct.encoder(styleImg, 'relu2_1')
70+
cF2 = wct.encoder(Im3, 'relu2_1')
71+
sF2 = sF2.cpu().squeeze(0)
72+
cF2 = cF2.cpu().squeeze(0)
73+
csF2 = wct.transform(cF2,sF2,args.alpha)
74+
csF2 = csF2.to(next(wct.parameters()).device)
7775
Im2 = wct.d2(csF2)
7876

79-
sF1 = wct.e1(styleImg)
80-
cF1 = wct.e1(Im2)
81-
sF1 = sF1.data.cpu().squeeze(0)
82-
cF1 = cF1.data.cpu().squeeze(0)
83-
csF1 = wct.transform(cF1,sF1,csF,args.alpha)
77+
sF1 = wct.encoder(styleImg, 'relu1_1')
78+
cF1 = wct.encoder(Im2, 'relu1_1')
79+
sF1 = sF1.cpu().squeeze(0)
80+
cF1 = cF1.cpu().squeeze(0)
81+
csF1 = wct.transform(cF1,sF1,args.alpha)
82+
csF1 = csF1.to(next(wct.parameters()).device)
8483
Im1 = wct.d1(csF1)
8584
# save_image has this weird design to pad images with 4 pixels by default.
86-
vutils.save_image(Im1.data.cpu().float(),os.path.join(args.outf,imname))
85+
vutils.save_image(Im1.cpu().float(),os.path.join(args.outf,imname))
8786
return
8887

89-
avgTime = 0
90-
cImg = torch.Tensor()
91-
sImg = torch.Tensor()
92-
csF = torch.Tensor()
93-
csF = Variable(csF)
94-
if(args.cuda):
95-
cImg = cImg.cuda(args.gpu)
96-
sImg = sImg.cuda(args.gpu)
97-
csF = csF.cuda(args.gpu)
98-
wct.cuda(args.gpu)
99-
for i,(contentImg,styleImg,imname) in enumerate(loader):
100-
imname = imname[0]
101-
print('Transferring ' + imname)
102-
if (args.cuda):
103-
contentImg = contentImg.cuda(args.gpu)
104-
styleImg = styleImg.cuda(args.gpu)
105-
cImg = Variable(contentImg,volatile=True)
106-
sImg = Variable(styleImg,volatile=True)
107-
start_time = time.time()
108-
# WCT Style Transfer
109-
styleTransfer(cImg,sImg,imname,csF)
110-
end_time = time.time()
111-
print('Elapsed time is: %f' % (end_time - start_time))
112-
avgTime += (end_time - start_time)
88+
def main():
89+
wct = WCT(args)
90+
if(args.cuda):
91+
wct.cuda(args.gpu)
11392

114-
print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
93+
avgTime = 0
94+
for i,(contentImg,styleImg,imname) in enumerate(loader):
95+
if(args.cuda):
96+
contentImg = contentImg.cuda(args.gpu)
97+
styleImg = styleImg.cuda(args.gpu)
98+
imname = imname[0]
99+
print('Transferring ' + imname)
100+
if (args.cuda):
101+
contentImg = contentImg.cuda(args.gpu)
102+
styleImg = styleImg.cuda(args.gpu)
103+
start_time = time.time()
104+
# WCT Style Transfer
105+
styleTransfer(wct, contentImg, styleImg, imname)
106+
end_time = time.time()
107+
print('Elapsed time is: %f' % (end_time - start_time))
108+
avgTime += (end_time - start_time)
109+
110+
print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
111+
112+
if __name__ == '__main__':
113+
with torch.no_grad():
114+
main()

0 commit comments

Comments
 (0)