
Commit 1fe108a

finish main
1 parent ca30f59 commit 1fe108a


8 files changed: +516, -399 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ MANIFEST
 *.pptx
 *.caffemodel
 result/
+data/
 
 # PyInstaller
 # Usually these files are written by a python script from a template
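The new data/ entry keeps the MNIST files out of version control: create_mnist_loader() in main.py downloads the dataset into ./data/ on first run, so the directory would otherwise show up as untracked.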

main.py

Lines changed: 222 additions & 1 deletion
@@ -4,5 +4,226 @@
 @Author : Wang Xin
 @Email : wangxin_buaa@163.com
 """
+import argparse
+import os
 
-# TODO: handwritten digit recognition
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+
+import torch.optim as optim
+
+from tqdm import tqdm
+from network.network import PlainNet, DeformNet, DeformNet_v2
+
+from utils import utils
+
+
+def parse_command():
+    model_names = ['plain', 'deform', 'deform_v2']
+
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--resume', default=None, type=str, metavar='PATH')
+    parser.add_argument('--model', type=str, default='plain', choices=model_names)
+    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
+                        help='input batch size for training (default: 32)')
+    parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
+                        help='input batch size for testing (default: 32)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                        help='learning rate (default: 0.01)')
+    parser.add_argument('--lr_patience', default=2, type=int,
+                        help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.')
+    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
+                        help='SGD momentum (default: 0.5)')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--gpu', default=None, type=str, help='if set, use only the single GPU with this id')
+    parser.add_argument('--print_freq', type=int, default=10, metavar='N',
+                        help='how many batches to wait before logging training status')
+    args = parser.parse_args()
+
+    return args
+
+
+def create_mnist_loader(args):
+    # MNIST dataset (downloaded into ./data/ on first run)
+    train_dataset = datasets.MNIST(root='./data/',
+                                   train=True,
+                                   transform=transforms.ToTensor(),
+                                   download=True)
+
+    test_dataset = datasets.MNIST(root='./data/',
+                                  train=False,
+                                  transform=transforms.ToTensor())
+
+    # Data loaders (input pipeline)
+    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                               batch_size=args.batch_size,
+                                               shuffle=True)
+
+    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
+                                              batch_size=args.test_batch_size,
+                                              shuffle=False)
+    return train_loader, test_loader
+
+
+def init_weight(net):
+    for m in net.modules():
+        if isinstance(m, nn.Conv2d):
+            torch.nn.init.xavier_normal_(m.weight)
+
+            if m.bias is not None:
+                m.bias.data.zero_()
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+
+
+def main():
+    args = parse_command()
+    print(args)
+
+    # if a GPU id is given, restrict training to that single GPU
+    if args.gpu:
+        print('Single GPU Mode.')
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+    # set random seed
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+
+    if torch.cuda.device_count() > 1:
+        print("Let's use", torch.cuda.device_count(), "GPUs!")
+        args.batch_size = args.batch_size * torch.cuda.device_count()
+    else:
+        print("Let's use GPU", torch.cuda.current_device())
+
+    train_loader, test_loader = create_mnist_loader(args)
+
+    # create save dir and logger
+    save_dir = utils.get_save_path(args)
+    utils.write_config_file(args, save_dir)
+    logger = utils.get_logger(save_dir)
+
+    best_result = 0.0
+    best_txt = os.path.join(save_dir, 'best.txt')
+
+    train_acc = 0.0
+    train_loss = 0.0
+
+    start_epoch = 0
+    start_iter = len(train_loader) * start_epoch + 1
+    max_iter = len(train_loader) * (args.epochs - start_epoch + 1) + 1
+    iter_save = len(train_loader)
+
+    if args.model == 'plain':
+        model = PlainNet()
+    elif args.model == 'deform':
+        model = DeformNet()
+    else:
+        model = DeformNet_v2()
+    model.apply(init_weight)
+    # DataParallel() works with a single GPU as well as with multiple GPUs
+    model = nn.DataParallel(model).cuda()
+
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
+    criterion = nn.NLLLoss()
+
+    # during training, use ReduceLROnPlateau to lower the learning rate when test accuracy stalls
+    scheduler = lr_scheduler.ReduceLROnPlateau(
+        optimizer, 'max', patience=args.lr_patience)
+
+    model.train()
+
+    loader_iter = iter(train_loader)
+    for it in tqdm(range(start_iter, max_iter + 1), total=max_iter, leave=False, dynamic_ncols=True):
+        optimizer.zero_grad()
+
+        try:
+            input, target = next(loader_iter)
+        except StopIteration:
+            # restart the loader once an epoch's worth of batches is exhausted
+            loader_iter = iter(train_loader)
+            input, target = next(loader_iter)
+
+        input = input.cuda()
+        target = target.cuda()
+
+        output = model(input)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.item()
+        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
+        per_acc = pred.eq(target.data.view_as(pred)).cpu().sum().item()
+        train_acc += per_acc
+
+        if it % args.print_freq == 0:
+            print('=> output: {}'.format(save_dir))
+            print('Train Iter: [{0}/{1}]\t'
+                  'Loss={Loss:.5f} '
+                  'Accuracy={Acc:.5f}'
+                  .format(it, max_iter, Loss=loss.item(), Acc=per_acc / args.batch_size))
+            logger.add_scalar('Train/Loss', loss.item(), it)
+            logger.add_scalar('Train/Acc', per_acc / args.batch_size, it)
+
+        if it % iter_save == 0:
+            epoch = it // iter_save
+            correct, test_loss = test(model, test_loader, it, logger)
+
+            # log the current learning rate of each parameter group
+            for i, param_group in enumerate(optimizer.param_groups):
+                old_lr = float(param_group['lr'])
+                logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
+
+            # record this epoch's train/test loss and train/test accuracy
+            train_loss /= len(train_loader.dataset)
+            train_acc /= len(train_loader.dataset)
+
+            logger.add_scalars('TrainVal/acc', {'train_acc': train_acc, 'test_acc': correct}, epoch)
+            logger.add_scalars('TrainVal/loss', {'train_loss': train_loss, 'test_loss': test_loss}, epoch)
+
+            train_loss = 0.0
+            train_acc = 0.0
+
+            # remember the best test accuracy and record it
+            is_best = correct > best_result
+            if is_best:
+                best_result = correct
+                with open(best_txt, 'w') as txtfile:
+                    txtfile.write("epoch={}, acc={}".format(epoch, correct))
+
+            scheduler.step(correct)
+
+            model.train()
+
+    logger.close()
+
+
+def test(model, test_loader, epoch, logger=None):
+    model.eval()
+    test_loss = 0.0
+    correct = 0.0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.cuda(), target.cuda()
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
+            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
+
+    test_loss /= len(test_loader.dataset)
+    correct /= len(test_loader.dataset)
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2%}\n'.format(test_loss, correct))
+
+    logger.add_scalar('Test/loss', test_loss, epoch)
+    logger.add_scalar('Test/acc', correct, epoch)
+
+    return float(correct), float(test_loss)
+
+
+if __name__ == '__main__':
+    main()
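For reference, a typical invocation using the flags defined in parse_command() might look like the following; the model choice, batch size, and GPU id are illustrative, not taken from the commit:

    python main.py --model deform --batch-size 64 --epochs 10 --lr 0.01 --gpu 0
    python main.py --model plain --print_freq 50

Note that --resume is parsed but never consumed in this file, so checkpoint loading is presumably handled elsewhere or still to come.

The logger returned by utils.get_logger() is used here only through add_scalar, add_scalars, and close, which matches the tensorboardX SummaryWriter API. utils/utils.py is not part of this diff, so the following is just a minimal sketch under that assumption; the function body and the logs subdirectory are hypothetical:

    # Hypothetical sketch of utils.get_logger(), assuming a tensorboardX SummaryWriter;
    # the real implementation lives in utils/utils.py, which this commit does not touch.
    import os
    from tensorboardX import SummaryWriter

    def get_logger(save_dir):
        log_dir = os.path.join(save_dir, 'logs')  # assumed layout
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir)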

0 commit comments
