|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | + @Time : 2019/2/20 22:16 |
| 4 | + @Author : Wang Xin |
| 5 | + @Email : wangxin_buaa@163.com |
| 6 | +""" |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | + |
| 13 | + |
| 14 | +class DeformConv2D(nn.Module): |
| 15 | + def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None): |
| 16 | + super(DeformConv2D, self).__init__() |
| 17 | + self.kernel_size = kernel_size |
| 18 | + self.padding = padding |
| 19 | + self.stride = stride |
| 20 | + self.zero_padding = nn.ZeroPad2d(padding) |
| 21 | + self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) |
| 22 | + |
| 23 | + def forward(self, x, offset): |
| 24 | + dtype = offset.data.type() |
| 25 | + ks = self.kernel_size |
| 26 | + N = offset.size(1) // 2 |
| 27 | + |
| 28 | + # Change offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...] |
| 29 | + # Codes below are written to make sure same results of MXNet implementation. |
| 30 | + # You can remove them, and it won't influence the module's performance. |
| 31 | + offsets_index = torch.cat([torch.arange(0, 2 * N, 2), torch.arange(1, 2 * N + 1, 2)]).type_as(x).long() |
| 32 | + offsets_index.requires_grad = False |
| 33 | + offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size()) |
| 34 | + offset = torch.gather(offset, dim=1, index=offsets_index) |
| 35 | + # ------------------------------------------------------------------------ |
| 36 | + |
| 37 | + if self.padding: |
| 38 | + x = self.zero_padding(x) |
| 39 | + |
| 40 | + # (b, 2N, h, w) |
| 41 | + p = self._get_p(offset, dtype) |
| 42 | + |
| 43 | + # (b, h, w, 2N) |
| 44 | + p = p.contiguous().permute(0, 2, 3, 1) |
| 45 | + |
| 46 | + """ |
| 47 | + if q is float, using bilinear interpolate, it has four integer corresponding position. |
| 48 | + The four position is left top, right top, left bottom, right bottom, defined as q_lt, q_rb, q_lb, q_rt |
| 49 | + """ |
| 50 | + # (b, h, w, 2N) |
| 51 | + q_lt = p.detach().floor() |
| 52 | + |
| 53 | + """ |
| 54 | + Because the shape of x is N, b, h, w, the pixel position is (y, x) |
| 55 | + *┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄→y |
| 56 | + ┊ .(y, x) .(y+1, x) |
| 57 | + ┊ |
| 58 | + ┊ .(y, x+1) .(y+1, x+1) |
| 59 | + ┊ |
| 60 | + ↓ |
| 61 | + x |
| 62 | +
|
| 63 | + For right bottom point, it'x = left top'y + 1, it'y = left top'y + 1 |
| 64 | + """ |
| 65 | + q_rb = q_lt + 1 |
| 66 | + |
| 67 | + """ |
| 68 | + x.size(2) is h, x.size(3) is w, make 0 <= p_y <= h - 1, 0 <= p_x <= w-1 |
| 69 | + """ |
| 70 | + q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)], |
| 71 | + dim=-1).long() |
| 72 | + |
| 73 | + """ |
| 74 | + x.size(2) is h, x.size(3) is w, make 0 <= p_y <= h - 1, 0 <= p_x <= w-1 |
| 75 | + """ |
| 76 | + q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)], |
| 77 | + dim=-1).long() |
| 78 | + |
| 79 | + """ |
| 80 | + For the left bottom point, it'x is equal to right bottom, it'y is equal to left top |
| 81 | + Therefore, it's y is from q_lt, it's x is from q_rb |
| 82 | + """ |
| 83 | + q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1) |
| 84 | + |
| 85 | + """ |
| 86 | + y from q_rb, x from q_lt |
| 87 | + For right top point, it's x is equal t to left top, it's y is equal to right bottom |
| 88 | + """ |
| 89 | + q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1) |
| 90 | + |
| 91 | + """ |
| 92 | + find p_y <= padding or p_y >= h - 1 - padding, find p_x <= padding or p_x >= x - 1 - padding |
| 93 | + This is to find the points in the area where the pixel value is meaningful. |
| 94 | + """ |
| 95 | + # (b, h, w, N) |
| 96 | + mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding), |
| 97 | + p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p) |
| 98 | + mask = mask.detach() |
| 99 | + # print('mask:', mask) |
| 100 | + |
| 101 | + floor_p = torch.floor(p) |
| 102 | + # print('floor_p = ', floor_p) |
| 103 | + |
| 104 | + """ |
| 105 | + when mask is 1, take floor_p; |
| 106 | + when mask is 0, take original p. |
| 107 | + When thr point in the padding area, interpolation is not meaningful and we can take the nearest |
| 108 | + point which is the most possible to have meaningful value. |
| 109 | + """ |
| 110 | + p = p * (1 - mask) + floor_p * mask |
| 111 | + p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1) |
| 112 | + |
| 113 | + """ |
| 114 | + In the paper, G(q, p) = g(q_x, p_x) * g(q_y, p_y) |
| 115 | + g(a, b) = max(0, 1-|a-b|) |
| 116 | + """ |
| 117 | + # bilinear kernel (b, h, w, N) |
| 118 | + g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) |
| 119 | + g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) |
| 120 | + g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) |
| 121 | + g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) |
| 122 | + |
| 123 | + # print('g_lt size is ', g_lt.size()) |
| 124 | + # print('g_lt unsqueeze size:', g_lt.unsqueeze(dim=1).size()) |
| 125 | + |
| 126 | + # (b, c, h, w, N) |
| 127 | + x_q_lt = self._get_x_q(x, q_lt, N) |
| 128 | + x_q_rb = self._get_x_q(x, q_rb, N) |
| 129 | + x_q_lb = self._get_x_q(x, q_lb, N) |
| 130 | + x_q_rt = self._get_x_q(x, q_rt, N) |
| 131 | + |
| 132 | + """ |
| 133 | + In the paper, x(p) = ΣG(p, q) * x(q), G is bilinear kernal |
| 134 | + """ |
| 135 | + # (b, c, h, w, N) |
| 136 | + x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ |
| 137 | + g_rb.unsqueeze(dim=1) * x_q_rb + \ |
| 138 | + g_lb.unsqueeze(dim=1) * x_q_lb + \ |
| 139 | + g_rt.unsqueeze(dim=1) * x_q_rt |
| 140 | + |
| 141 | + x_offset = self._reshape_x_offset(x_offset, ks) |
| 142 | + |
| 143 | + out = self.conv_kernel(x_offset) |
| 144 | + return x_offset |
| 145 | + |
| 146 | + def _get_p_n(self, N, dtype): |
| 147 | + """ |
| 148 | + In torch 0.4.1 grid_x, grid_y = torch.meshgrid([x, y]) |
| 149 | + In torch 1.0 grid_x, grid_y = torch.meshgrid(x, y) |
| 150 | + """ |
| 151 | + p_n_x, p_n_y = torch.meshgrid( |
| 152 | + [torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1), |
| 153 | + torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1)]) |
| 154 | + # (2N, 1) |
| 155 | + p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) |
| 156 | + p_n = p_n.view(1, 2 * N, 1, 1).type(dtype) |
| 157 | + p_n.requires_grad = False |
| 158 | + # print('requires_grad:', p_n.requires_grad) |
| 159 | + |
| 160 | + return p_n |
| 161 | + |
| 162 | + def _get_p_0(self, h, w, N, dtype): |
| 163 | + p_0_x, p_0_y = torch.meshgrid([ |
| 164 | + torch.arange(1, h * self.stride + 1, self.stride), |
| 165 | + torch.arange(1, w * self.stride + 1, self.stride)]) |
| 166 | + p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) |
| 167 | + p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) |
| 168 | + p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) |
| 169 | + p_0.requires_grad = False |
| 170 | + |
| 171 | + return p_0 |
| 172 | + |
| 173 | + def _get_p(self, offset, dtype): |
| 174 | + N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) |
| 175 | + |
| 176 | + # (1, 2N, 1, 1) |
| 177 | + p_n = self._get_p_n(N, dtype) |
| 178 | + |
| 179 | + # (1, 2N, h, w) |
| 180 | + p_0 = self._get_p_0(h, w, N, dtype) |
| 181 | + |
| 182 | + p = p_0 + p_n + offset |
| 183 | + return p |
| 184 | + |
| 185 | + def _get_x_q(self, x, q, N): |
| 186 | + b, h, w, _ = q.size() |
| 187 | + padded_w = x.size(3) |
| 188 | + c = x.size(1) |
| 189 | + # (b, c, h*w) |
| 190 | + x = x.contiguous().view(b, c, -1) |
| 191 | + |
| 192 | + # (b, h, w, N) |
| 193 | + index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y |
| 194 | + # (b, c, h*w*N) |
| 195 | + index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) |
| 196 | + |
| 197 | + x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) |
| 198 | + |
| 199 | + return x_offset |
| 200 | + |
| 201 | + @staticmethod |
| 202 | + def _reshape_x_offset(x_offset, ks): |
| 203 | + b, c, h, w, N = x_offset.size() |
| 204 | + x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], |
| 205 | + dim=-1) |
| 206 | + x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks) |
| 207 | + |
| 208 | + return x_offset |
| 209 | + |
| 210 | + |
| 211 | +from network.deform_conv.deform_conv import DeformConv2D as DeformConv2D_ori |
| 212 | +from time import time |
| 213 | + |
| 214 | +if __name__ == '__main__': |
| 215 | + x = torch.randn(4, 3, 255, 255) |
| 216 | + |
| 217 | + p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1) |
| 218 | + conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, bias=False) |
| 219 | + |
| 220 | + d_conv1 = DeformConv2D(3, 64) |
| 221 | + d_conv2 = DeformConv2D_ori(3, 64) |
| 222 | + |
| 223 | + offset = p_conv(x) |
| 224 | + |
| 225 | + end = time() |
| 226 | + y1 = conv(d_conv1(x, offset)) |
| 227 | + end = time() - end |
| 228 | + print('#1 speed = ', end) |
| 229 | + |
| 230 | + end = time() |
| 231 | + y2 = conv(d_conv2(x, offset)) |
| 232 | + end = time() - end |
| 233 | + print('#2 speed = ', end) |
| 234 | + |
| 235 | + # mask = (y1 == y2) |
| 236 | + # print(mask) |
| 237 | + # print(torch.max(mask)) |
| 238 | + # print(torch.min(mask)) |
| 239 | + |
0 commit comments