Skip to content

Commit aea497d

Browse files
committed
update network
1 parent fbcd06e commit aea497d

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def parse_command():
2626

2727
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
2828
parser.add_argument('--resume', default=None, type=str, metavar='PATH')
29-
parser.add_argument('--model', type=str, default='deform', choices=model_names)
29+
parser.add_argument('--model', type=str, default='deform_v2', choices=model_names)
3030
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
3131
help='input batch size for training (default: 32)')
3232
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',

network/deform_conv/deform_conv_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
"""
1414

1515

16-
class DeformConv2d(nn.Module):
16+
class DeformConv2D(nn.Module):
1717
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False, lr_ratio = 0.1):
1818
"""
1919
Args:
2020
modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
2121
"""
22-
super(DeformConv2d, self).__init__()
22+
super(DeformConv2D, self).__init__()
2323
self.kernel_size = kernel_size
2424
self.padding = padding
2525
self.stride = stride

network/network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def forward(self, x):
8383

8484
class DeformNet_v2(nn.Module):
8585
def __init__(self):
86-
super(DeformNet, self).__init__()
86+
super(DeformNet_v2, self).__init__()
8787
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
8888
self.bn1 = nn.BatchNorm2d(32)
8989

@@ -93,7 +93,7 @@ def __init__(self):
9393
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
9494
self.bn3 = nn.BatchNorm2d(128)
9595

96-
self.conv4 = deform_conv_v2.DeformConv2D(128, 128, kernel_size=3, padding=1)
96+
self.conv4 = deform_conv_v2.DeformConv2D(128, 128, kernel_size=3, padding=1, modulation=True)
9797
self.bn4 = nn.BatchNorm2d(128)
9898

9999
self.classifier = nn.Linear(128, 10)

0 commit comments

Comments
 (0)