Skip to content

Commit fbcd06e

Browse files
committed
update network
1 parent a0e8b29 commit fbcd06e

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def parse_command():
2626

2727
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
2828
parser.add_argument('--resume', default=None, type=str, metavar='PATH')
29-
parser.add_argument('--model', type=str, default='plain', choices=model_names)
29+
parser.add_argument('--model', type=str, default='deform', choices=model_names)
3030
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
3131
help='input batch size for training (default: 32)')
3232
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',

network/network.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __init__(self):
5858
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
5959
self.bn3 = nn.BatchNorm2d(128)
6060

61-
self.offsets = nn.Conv2d(128, 18, kernel_size=3, padding=1)
6261
self.conv4 = deform_conv.DeformConv2D(128, 128, kernel_size=3, padding=1)
6362
self.bn4 = nn.BatchNorm2d(128)
6463

@@ -73,8 +72,7 @@ def forward(self, x):
7372
x = F.relu(self.conv3(x))
7473
x = self.bn3(x)
7574
# deformable convolution
76-
offsets = self.offsets(x)
77-
x = F.relu(self.conv4(x, offsets))
75+
x = F.relu(self.conv4(x))
7876
x = self.bn4(x)
7977

8078
x = F.avg_pool2d(x, kernel_size=28, stride=1).view(x.size(0), -1)
@@ -95,7 +93,6 @@ def __init__(self):
9593
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
9694
self.bn3 = nn.BatchNorm2d(128)
9795

98-
self.offsets = nn.Conv2d(128, 18, kernel_size=3, padding=1)
9996
self.conv4 = deform_conv_v2.DeformConv2D(128, 128, kernel_size=3, padding=1)
10097
self.bn4 = nn.BatchNorm2d(128)
10198

@@ -110,8 +107,7 @@ def forward(self, x):
110107
x = F.relu(self.conv3(x))
111108
x = self.bn3(x)
112109
# deformable convolution
113-
offsets = self.offsets(x)
114-
x = F.relu(self.conv4(x, offsets))
110+
x = F.relu(self.conv4(x))
115111
x = self.bn4(x)
116112

117113
x = F.avg_pool2d(x, kernel_size=28, stride=1).view(x.size(0), -1)

0 commit comments

Comments
 (0)