|
8 | 8 | from Loader import Dataset |
9 | 9 | from util import * |
10 | 10 | import scipy.misc |
11 | | -from torch.utils.serialization import load_lua |
12 | 11 | import time |
13 | 12 |
|
# Command-line options for the WCT style-transfer driver.
parser = argparse.ArgumentParser(description='WCT Pytorch')
parser.add_argument('--contentPath', default='images/content', help='path to train')
parser.add_argument('--stylePath', default='images/style', help='path to train')
parser.add_argument('--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--encoder', default='models/vgg19_normalized.pth.tar',
                    help='Path to the VGG conv1_1')
# One checkpoint path per decoder level (decoder5 ... decoder1).
for _level in (5, 4, 3, 2, 1):
    parser.add_argument('--decoder%d' % _level,
                        default='models/vgg19_normalized_decoder%d.pth.tar' % _level,
                        help='Path to the decoder%d' % _level)
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--fineSize', type=int, default=512,
                    help='resize image to fineSize x fineSize,leave it to 0 if not resize')
|
45 | 40 | batch_size=1, |
46 | 41 | shuffle=False) |
47 | 42 |
|
48 | | -wct = WCT(args) |
def styleTransfer(wct, contentImg, styleImg, imname):
    """Run coarse-to-fine multi-level WCT style transfer and save the result.

    The image is stylized in five passes, from the deepest VGG layer
    (relu5_1) down to the shallowest (relu1_1).  At each level the style
    and current content images are encoded, whitened-and-colored on CPU,
    and the blended feature is decoded back to image space, which becomes
    the content input for the next (finer) level.

    Args:
        wct: WCT model exposing ``encoder(img, layer)``, decoders
            ``d5`` .. ``d1``, and ``transform(cF, sF, alpha)``.
        contentImg: content image tensor, batch dimension of 1.
        styleImg: style image tensor, batch dimension of 1.
        imname: file name under ``args.outf`` for the stylized output.

    Returns:
        None.  The stylized image is written to ``os.path.join(args.outf, imname)``.
    """
    # All five levels follow the identical encode/transform/decode recipe;
    # drive them from a table instead of five copy-pasted stanzas.
    device = next(wct.parameters()).device
    levels = [
        ('relu5_1', wct.d5),
        ('relu4_1', wct.d4),
        ('relu3_1', wct.d3),
        ('relu2_1', wct.d2),
        ('relu1_1', wct.d1),
    ]
    img = contentImg
    for layer, decoder in levels:
        sF = wct.encoder(styleImg, layer).cpu().squeeze(0)
        cF = wct.encoder(img, layer).cpu().squeeze(0)
        # WCT whitening/coloring runs on CPU; move the result back to the
        # model's device before decoding.
        csF = wct.transform(cF, sF, args.alpha)
        img = decoder(csF.to(device))
    # save_image has this weird design to pad images with 4 pixels at default.
    vutils.save_image(img.cpu().float(), os.path.join(args.outf, imname))
    return
88 | 87 |
|
89 | | -avgTime = 0 |
90 | | -cImg = torch.Tensor() |
91 | | -sImg = torch.Tensor() |
92 | | -csF = torch.Tensor() |
93 | | -csF = Variable(csF) |
94 | | -if(args.cuda): |
95 | | - cImg = cImg.cuda(args.gpu) |
96 | | - sImg = sImg.cuda(args.gpu) |
97 | | - csF = csF.cuda(args.gpu) |
98 | | - wct.cuda(args.gpu) |
99 | | -for i,(contentImg,styleImg,imname) in enumerate(loader): |
100 | | - imname = imname[0] |
101 | | - print('Transferring ' + imname) |
102 | | - if (args.cuda): |
103 | | - contentImg = contentImg.cuda(args.gpu) |
104 | | - styleImg = styleImg.cuda(args.gpu) |
105 | | - cImg = Variable(contentImg,volatile=True) |
106 | | - sImg = Variable(styleImg,volatile=True) |
107 | | - start_time = time.time() |
108 | | - # WCT Style Transfer |
109 | | - styleTransfer(cImg,sImg,imname,csF) |
110 | | - end_time = time.time() |
111 | | - print('Elapsed time is: %f' % (end_time - start_time)) |
112 | | - avgTime += (end_time - start_time) |
def main():
    """Stylize every (content, style) pair from ``loader`` and report timing.

    Builds the WCT model, moves it (and each batch) to the GPU when
    ``--cuda`` is set, runs :func:`styleTransfer` per image, and prints
    per-image and average elapsed times.
    """
    wct = WCT(args)
    if args.cuda:
        wct.cuda(args.gpu)

    total_time = 0.0
    count = 0
    for contentImg, styleImg, imname in loader:
        # Move inputs to the GPU exactly once (the original did this twice).
        if args.cuda:
            contentImg = contentImg.cuda(args.gpu)
            styleImg = styleImg.cuda(args.gpu)
        imname = imname[0]
        print('Transferring ' + imname)
        start_time = time.time()
        # WCT Style Transfer
        styleTransfer(wct, contentImg, styleImg, imname)
        end_time = time.time()
        print('Elapsed time is: %f' % (end_time - start_time))
        total_time += end_time - start_time
        count += 1

    # Guard against an empty loader: the original referenced the loop index
    # after the loop (NameError) and divided by zero.
    if count:
        print('Processed %d images. Averaged time is %f' % (count, total_time / count))

if __name__ == '__main__':
    # Inference only; disable autograd bookkeeping globally.
    with torch.no_grad():
        main()
0 commit comments