@Author : Wang Xin
@Email : wangxin_buaa@163.com
"""
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms

from tqdm import tqdm

from network.network import PlainNet, DeformNet, DeformNet_v2
from utils import utils


def parse_command():
    model_names = ['plain', 'deform', 'deform_v2']

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--resume', default=None, type=str, metavar='PATH')
    parser.add_argument('--model', type=str, default='plain', choices=model_names)
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
                        help='input batch size for testing (default: 32)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--lr_patience', default=2, type=int,
                        help='patience of LR scheduler; see documentation of ReduceLROnPlateau')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--gpu', default=None, type=str,
                        help='if set, train on this single GPU id')
    parser.add_argument('--print_freq', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    return args

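# Example invocation (the script name and GPU id here are assumptions):
#   python main.py --model deform --batch-size 64 --lr 0.01 --gpu 0
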
def create_mnist_loader(args):
    # MNIST dataset (downloaded to ./data/ on first use)
    train_dataset = datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)

    test_dataset = datasets.MNIST(root='./data/',
                                  train=False,
                                  transform=transforms.ToTensor())

    # Data loaders (input pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # The test loader uses the dedicated --test-batch-size option.
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False)
    return train_loader, test_loader

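# Note: transforms.ToTensor() scales MNIST pixels to [0, 1]. Mean/std
# normalization, e.g. transforms.Normalize((0.1307,), (0.3081,)), is a common
# addition but is not applied in this script.
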
def init_weight(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            # Xavier (Glorot) normal initialization for convolution weights.
            torch.nn.init.xavier_normal_(m.weight)

            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # Start BatchNorm as the identity transform: scale 1, shift 0.
            m.weight.data.fill_(1)
            m.bias.data.zero_()


def main():
    args = parse_command()
    print(args)

    # If a GPU id was given, restrict training to that single GPU.
    if args.gpu:
        print('Single GPU Mode.')
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Set random seeds (CPU and all visible GPUs) for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # Scale the global batch size with the number of GPUs DataParallel will use.
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU", torch.cuda.current_device())

    train_loader, test_loader = create_mnist_loader(args)

    # Create the save directory and logger.
    save_dir = utils.get_save_path(args)
    utils.write_config_file(args, save_dir)
    logger = utils.get_logger(save_dir)

    best_result = 0.0
    best_txt = os.path.join(save_dir, 'best.txt')

    train_acc = 0.0
    train_loss = 0.0

    # Iteration bookkeeping: one epoch is len(train_loader) iterations.
    start_epoch = 0
    start_iter = len(train_loader) * start_epoch + 1
    max_iter = len(train_loader) * args.epochs
    iter_save = len(train_loader)

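    # Build the requested network variant. The deformable variants are assumed
    # to swap deformable convolutions into the plain architecture (see
    # network/network.py).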
    if args.model == 'plain':
        model = PlainNet()
    elif args.model == 'deform':
        model = DeformNet()
    else:
        model = DeformNet_v2()
    model.apply(init_weight)
    # DataParallel works with a single GPU as well as with multiple GPUs.
    model = nn.DataParallel(model).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    # NLLLoss expects log-probabilities, so each network is assumed to end
    # with a log_softmax layer.
    criterion = nn.NLLLoss()

    # Reduce the learning rate when the monitored test accuracy plateaus
    # ('max' mode, since higher accuracy is better).
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, 'max', patience=args.lr_patience)

    model.train()

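    # Iteration-based training: the loop below counts optimizer steps rather
    # than epochs, and evaluation/bookkeeping runs every iter_save
    # (i.e. one epoch of) steps.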
    loader_iter = iter(train_loader)
    for it in tqdm(range(start_iter, max_iter + 1), total=max_iter, leave=False, dynamic_ncols=True):
        optimizer.zero_grad()

        # Fetch the next batch, restarting the iterator when an epoch ends.
        try:
            input, target = next(loader_iter)
        except StopIteration:
            loader_iter = iter(train_loader)
            input, target = next(loader_iter)

        input = input.cuda()
        target = target.cuda()

        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
        per_acc = pred.eq(target.data.view_as(pred)).cpu().sum().item()
        train_acc += per_acc

        if it % args.print_freq == 0:
            print('=> output: {}'.format(save_dir))
            print('Train Iter: [{0}/{1}]\t'
                  'Loss={Loss:.5f} '
                  'Accuracy={Acc:.5f}'
                  .format(it, max_iter, Loss=loss.item(), Acc=per_acc / args.batch_size))
            logger.add_scalar('Train/Loss', loss.item(), it)
            logger.add_scalar('Train/Acc', per_acc / args.batch_size, it)

        if it % iter_save == 0:
            epoch = it // iter_save
            correct, test_loss = test(model, test_loader, it, logger)

            # Log the current learning rate of every parameter group.
            for i, param_group in enumerate(optimizer.param_groups):
                old_lr = float(param_group['lr'])
                logger.add_scalar('Lr/lr_' + str(i), old_lr, it)

            # Average the accumulated metrics over the epoch: loss was summed
            # per batch (mean-reduced), accuracy as a count of correct samples.
            train_loss /= len(train_loader)
            train_acc /= len(train_loader.dataset)

            logger.add_scalars('TrainVal/acc', {'train_acc': train_acc, 'test_acc': correct}, epoch)
            logger.add_scalars('TrainVal/loss', {'train_loss': train_loss, 'test_loss': test_loss}, epoch)

            train_loss = 0.0
            train_acc = 0.0

            # Remember the best test accuracy seen so far.
            is_best = correct > best_result
            if is_best:
                best_result = correct
                with open(best_txt, 'w') as txtfile:
                    txtfile.write("epoch={}, acc={}".format(epoch, correct))

            scheduler.step(correct)

            model.train()

    logger.close()


def test(model, test_loader, epoch, logger=None):
    model.eval()
    test_loss = 0.0
    correct = 0.0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    correct /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2%}\n'.format(test_loss, correct))

    if logger is not None:
        logger.add_scalar('Test/loss', test_loss, epoch)
        logger.add_scalar('Test/acc', correct, epoch)

    return float(correct), float(test_loss)


if __name__ == '__main__':
    main()