@@ -27,13 +27,13 @@ def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, mod
         self.zero_padding = nn.ZeroPad2d(padding)
         self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
 
-        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
+        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
         nn.init.constant_(self.p_conv.weight, 0)
         self.p_conv.register_backward_hook(self._set_lr)
 
         self.modulation = modulation
         if modulation:
-            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
+            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
             nn.init.constant_(self.m_conv.weight, 0.5)
             self.m_conv.register_backward_hook(self._set_lr)
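
For context on this hunk: `p_conv` predicts an (x, y) offset for each of the kernel_size² sampling locations, hence 2·kernel_size² output channels, while `m_conv` predicts one scalar modulation mask per location (kernel_size² channels), as in Deformable ConvNets v2. Zero-initializing `p_conv` makes the layer start out as a plain convolution. A minimal shape check, with made-up input sizes:

```python
# Minimal sketch of the channel bookkeeping above; inc, kernel_size and
# stride follow the constructor, the 32x32 input is an arbitrary example.
import torch
import torch.nn as nn

inc, kernel_size, stride = 64, 3, 1
p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)

x = torch.randn(1, inc, 32, 32)
print(p_conv(x).shape)  # torch.Size([1, 18, 32, 32]) -- one (dx, dy) pair per kernel sample
print(m_conv(x).shape)  # torch.Size([1, 9, 32, 32])  -- one mask value per kernel sample
```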
@@ -62,15 +62,13 @@ def forward(self, x):
         q_lt = p.detach().floor()
         q_rb = q_lt + 1
 
-        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
-                         dim=-1).long()
-        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
-                         dim=-1).long()
+        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
+        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
         q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
         q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
 
         # clip p
-        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
+        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
 
         # bilinear kernel (b, h, w, N)
         g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
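
The four `g_*` weights computed here are the standard bilinear interpolation coefficients: since `q_lt = floor(p)`, the term `1 + (q_lt - p)` equals `1 - (p - q_lt)`, so the closer the fractional point `p` is to a corner, the larger that corner's weight, and the four weights always sum to 1. A scalar sketch for one sampling point (values are made up):

```python
# Scalar sketch of the bilinear weights computed above for one
# fractional sampling point p; q_lt/q_rb are its floor/ceil corners.
import torch

p = torch.tensor([2.3, 5.7])        # made-up fractional (x, y)
q_lt, q_rb = p.floor(), p.floor() + 1

g_lt = (1 + (q_lt[0] - p[0])) * (1 + (q_lt[1] - p[1]))  # 0.7 * 0.3
g_rb = (1 - (q_rb[0] - p[0])) * (1 - (q_rb[1] - p[1]))  # 0.3 * 0.7
g_lb = (1 + (q_lt[0] - p[0])) * (1 - (q_rb[1] - p[1]))  # 0.7 * 0.7
g_rt = (1 - (q_rb[0] - p[0])) * (1 + (q_lt[1] - p[1]))  # 0.3 * 0.3
print(g_lt + g_rb + g_lb + g_rt)  # tensor(1.) up to float rounding
```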
@@ -104,26 +102,26 @@ def forward(self, x):
 
     def _get_p_n(self, N, dtype):
         p_n_x, p_n_y = torch.meshgrid(
-            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
-            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1))
+            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
+            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1))
         # (2N, 1)
         p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
-        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
+        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
 
         return p_n
 
     def _get_p_0(self, h, w, N, dtype):
         p_0_x, p_0_y = torch.meshgrid(
-            torch.arange(1, h * self.stride + 1, self.stride),
-            torch.arange(1, w * self.stride + 1, self.stride))
+            torch.arange(1, h * self.stride + 1, self.stride),
+            torch.arange(1, w * self.stride + 1, self.stride))
         p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
         p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
         p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
 
         return p_0
 
     def _get_p(self, offset, dtype):
-        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
+        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
 
         # (1, 2N, 1, 1)
         p_n = self._get_p_n(N, dtype)
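
Read together, `_get_p` returns `p = p_0 + p_n + offset`: `p_0` is the regular grid of output-pixel centers (1-based, to account for the zero padding), `p_n` is the fixed kernel-shaped neighborhood around each center, and `offset` is the learned deformation from `p_conv`. A sketch of the `p_n` grid for kernel_size=3 (N=9), using the same meshgrid convention as the diff:

```python
# Sketch of the fixed kernel grid p_n built by _get_p_n for kernel_size=3.
# (Newer PyTorch versions want an explicit indexing='ij' in meshgrid; the
# diff relies on the old default, which is equivalent.)
import torch

kernel_size = 3
r = (kernel_size - 1) // 2
p_n_x, p_n_y = torch.meshgrid(torch.arange(-r, r + 1), torch.arange(-r, r + 1))
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
print(p_n.view(2, -1))
# tensor([[-1, -1, -1,  0,  0,  0,  1,  1,  1],    x components of the 3x3 grid
#         [-1,  0,  1, -1,  0,  1, -1,  0,  1]])   y components of the 3x3 grid
```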
@@ -140,7 +138,7 @@ def _get_x_q(self, x, q, N):
         x = x.contiguous().view(b, c, -1)
 
         # (b, h, w, N)
-        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
+        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
         # (b, c, h*w*N)
         index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
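
The `index` line is the usual row-major flattening trick: once `x` is viewed as `(b, c, h*w)`, a 2-D coordinate `(row, col)` becomes the single index `row * padded_w + col`, so one `gather` call can pull out all N sampled neighbors per output location. A toy example:

```python
# Toy sketch of the flattened-gather trick used in _get_x_q.
import torch

x = torch.arange(16.).view(1, 1, 4, 4)   # (b, c, h, w) image with values 0..15
flat = x.view(1, 1, -1)                  # (b, c, h*w)
rows, cols = torch.tensor([1, 2]), torch.tensor([3, 0])
index = rows * 4 + cols                  # row-major: tensor([7, 8])
print(torch.gather(flat, -1, index.view(1, 1, -1)))  # tensor([[[7., 8.]]])
```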
@@ -151,9 +149,8 @@ def _get_x_q(self, x, q, N):
     @staticmethod
     def _reshape_x_offset(x_offset, ks):
         b, c, h, w, N = x_offset.size()
-        x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
-                             dim=-1)
-        x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)
+        x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], dim=-1)
+        x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)
 
         return x_offset
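
This reshape rearranges the sampled tensor `(b, c, h, w, N)` into a `(b, c, h*ks, w*ks)` image in which every output location owns a ks×ks tile holding its N = ks² samples; that is why `self.conv` in `__init__` uses `stride=kernel_size`, so the final convolution consumes exactly one tile per output pixel. A self-contained toy run of the same rearrangement:

```python
# Toy run of the _reshape_x_offset rearrangement for ks=3 on a 2x2 map.
import torch

b, c, h, w, ks = 1, 1, 2, 2, 3
N = ks * ks
x_offset = torch.arange(b * c * h * w * N).view(b, c, h, w, N)
out = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks)
                 for s in range(0, N, ks)], dim=-1)
out = out.contiguous().view(b, c, h * ks, w * ks)
print(out.shape)          # torch.Size([1, 1, 6, 6])
print(out[0, 0, :3, :3])  # the nine samples of location (0, 0), in kernel order
```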
159156