Skip to content

Commit a7a8e29

Browse files
implement Lu et al transfer
1 parent dd58d67 commit a7a8e29

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

WCT.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import torch
55
import argparse
6+
import pprint
67
from PIL import Image
78
from torch.autograd import Variable
89
import torchvision.utils as vutils
@@ -23,6 +24,9 @@
2324
parser.add_argument('--decoder2', default='models/vgg19_normalized_decoder2.pth.tar', help='Path to the decoder2')
2425
parser.add_argument('--decoder1', default='models/vgg19_normalized_decoder1.pth.tar', help='Path to the decoder1')
2526
parser.add_argument('--cuda', action='store_true', help='enables cuda')
27+
parser.add_argument('--transform-method', choices=['original', 'closed-form'], default='original',
28+
help=('How to whiten and color the features. "original" for the formulation of Li et al. ( https://arxiv.org/abs/1705.08086 ) '
29+
'or "closed-form" for the method of Lu et al. ( https://arxiv.org/abs/1906.00668 )'))
2630
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
2731
parser.add_argument('--fineSize', type=int, default=512, help='resize image to fineSize x fineSize,leave it to 0 if not resize')
2832
parser.add_argument('--outf', default='samples/', help='folder to output images')
@@ -32,6 +36,7 @@
3236
parser.add_argument('--gpu', type=int, default=0, help="which gpu to run on. default is 0")
3337

3438
args = parser.parse_args()
39+
pprint.pprint(args.__dict__, indent=2)
3540

3641
try:
3742
os.makedirs(args.outf)
@@ -44,7 +49,7 @@
4449
batch_size=1,
4550
shuffle=False)
4651

47-
def styleTransfer(wct, targets, contentImg, styleImg, imname, gamma, delta, outf):
52+
def styleTransfer(wct, targets, contentImg, styleImg, imname, gamma, delta, outf, transform_method):
4853

4954
current_result = contentImg
5055
eIorigs = [f.cpu().squeeze(0) for f in wct.encoder(contentImg, targets)]
@@ -58,8 +63,8 @@ def styleTransfer(wct, targets, contentImg, styleImg, imname, gamma, delta, outf
5863
else:
5964
eIlast = wct.encoder(current_result, target).cpu().squeeze(0)
6065

61-
CsIlast = wct.transform(eIlast, eIs).float()
62-
CsIorig = wct.transform(eIorig, eIs).float()
66+
CsIlast = wct.transform(eIlast, eIs, transform_method).float()
67+
CsIorig = wct.transform(eIorig, eIs, transform_method).float()
6368

6469
decoder_input = (gamma*(delta * CsIlast + (1-delta) * CsIorig) \
6570
+ (1-gamma) * eIorig)
@@ -91,7 +96,7 @@ def main():
9196
# WCT Style Transfer
9297
targets = [f'relu{t}_1' for t in args.targets]
9398
styleTransfer(wct, targets, contentImg, styleImg, imname,
94-
args.gamma, args.delta, args.outf)
99+
args.gamma, args.delta, args.outf, args.transform_method)
95100
end_time = time.time()
96101
print(' Elapsed time is: %f' % (end_time - start_time))
97102
avgTime += (end_time - start_time)

util.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313

1414
def matrix_sqrt(A):
15+
A = A.clone()
16+
a_diag_ = A.diagonal()
17+
a_diag_ += 1e-4
18+
1519
s_u, s_e, s_v = torch.svd(A,some=False)
1620

1721
k_s = A.shape[-1]
@@ -27,6 +31,9 @@ def matrix_sqrt(A):
2731

2832

2933
def matrix_inv_sqrt(A):
34+
A = A.clone()
35+
a_diag_ = A.diagonal()
36+
a_diag_ += 1e-4
3037
k_c = A.shape[-1]
3138
c_u,c_e,c_v = torch.svd(A, some=False)
3239

@@ -65,36 +72,49 @@ def __init__(self,args):
6572
'relu4_1': self.d4,
6673
'relu5_1': self.d5}
6774

def whiten_and_color(self, cF, sF, method):
    """Whiten the content features and re-color them with style statistics.

    Args:
        cF: content features as a 2-D matrix, shape (C, H*W) — assumes the
            caller has already flattened the spatial dimensions (see
            ``transform``).
        sF: style features as a 2-D matrix, shape (C, H'*W').
        method: 'original' for the WCT formulation of Li et al.
            (https://arxiv.org/abs/1705.08086) or 'closed-form' for the
            closed-form transform of Lu et al.
            (https://arxiv.org/abs/1906.00668).

    Returns:
        Colored content features with the same shape as ``cF``.
    """
    cFSize = cF.size()
    # Center the content features and build their covariance.
    # torch.eye(...) regularizes the content covariance so its inverse
    # square root is well conditioned.
    c_mean = torch.mean(cF, 1)
    cF = cF - c_mean.unsqueeze(1).expand_as(cF)
    contentConv = torch.mm(cF, cF.t()).div(cFSize[1] - 1) + torch.eye(cFSize[0]).double()

    # Center the style features and build their covariance.
    sFSize = sF.size()
    s_mean = torch.mean(sF, 1)
    sF = sF - s_mean.unsqueeze(1).expand_as(sF)
    styleConv = torch.mm(sF, sF.t()).div(sFSize[1] - 1)

    if method == 'original':
        # Li et al.: whiten with Cc^{-1/2}, then color with Cs^{1/2}.
        cF_inv_sqrt = matrix_inv_sqrt(contentConv)
        sF_sqrt = matrix_sqrt(styleConv)
        targetFeature = sF_sqrt @ (cF_inv_sqrt @ cF)
    else:
        # Lu et al. closed-form solution:
        # T = Cc^{-1/2} (Cc^{1/2} Cs Cc^{1/2})^{1/2} Cc^{-1/2}.
        assert method == 'closed-form'
        cF_sqrt = matrix_sqrt(contentConv)
        cF_inv_sqrt = matrix_inv_sqrt(contentConv)
        middle_matrix = matrix_sqrt(cF_sqrt @ styleConv @ cF_sqrt)
        transform_matrix = cF_inv_sqrt @ middle_matrix @ cF_inv_sqrt
        targetFeature = transform_matrix @ cF

    # Shift the whitened-and-colored features onto the style mean.
    targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
    return targetFeature
88108

def transform(self, cF, sF, method):
    """Apply whiten-and-color to 3-D feature maps and restore their shape.

    Args:
        cF: content feature map, shape (C, H, W).
        sF: style feature map, shape (C, H', W') — same channel count C.
        method: transform formulation, forwarded to ``whiten_and_color``
            ('original' or 'closed-form').

    Returns:
        The colored content features reshaped back to ``cF``'s (C, H, W),
        as a double tensor.
    """
    # Work in double precision for the SVD-based matrix roots.
    cF = cF.double()
    sF = sF.double()
    # Flatten (C, H, W) -> (C, H*W) so features become column vectors.
    channels = cF.size(0)
    content_flat = cF.view(channels, -1)
    style_flat = sF.view(channels, -1)

    targetFeature = self.whiten_and_color(content_flat, style_flat, method)
    return targetFeature.view_as(cF)
100120

0 commit comments

Comments
 (0)