Skip to content

Commit a0e8b29

Browse files
committed
fix some bugs
1 parent 1fe108a commit a0e8b29

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

main.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def parse_command():
4242
parser.add_argument('--seed', type=int, default=1, metavar='S',
4343
help='random seed (default: 1)')
4444
parser.add_argument('--gpu', default=None, type=str, help='if not none, use Single GPU')
45-
parser.add_argument('--print_freq', type=int, default=10, metavar='N',
45+
parser.add_argument('--print_freq', type=int, default=50, metavar='N',
4646
help='how many batches to wait before logging training status')
4747
args = parser.parse_args()
4848

@@ -158,17 +158,17 @@ def main():
158158

159159
train_loss += loss.data[0]
160160
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
161-
per_acc = pred.eq(target.data.view_as(pred)).cpu().sum()
162-
train_acc += per_acc
161+
per_acc = pred.eq(target.data.view_as(pred)).sum()
162+
train_acc += per_acc.cpu()
163163

164164
if it % args.print_freq == 0:
165165
print('=> output: {}'.format(save_dir))
166166
print('Train Iter: [{0}/{1}]\t'
167167
'Loss={Loss:.5f} '
168168
'Accuracy={Acc:.5f}'
169-
.format(it, max_iter, Loss=loss, Acc=per_acc / args.batch_size))
169+
.format(it, max_iter, Loss=loss, Acc=float(per_acc) / args.batch_size))
170170
logger.add_scalar('Train/Loss', loss, it)
171-
logger.add_scalar('Train/Acc', per_acc / args.batch_size, it)
171+
# logger.add_scalar('Train/Acc', per_acc / args.batch_size, it)
172172

173173
if it % iter_save == 0:
174174
epoch = it // iter_save
@@ -180,6 +180,8 @@ def main():
180180
logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
181181

182182
# remember change of train/test loss and train/test acc
183+
train_loss = float(train_loss)
184+
train_acc = float(train_acc)
183185
train_loss /= len(train_loader.dataset)
184186
train_acc /= len(train_loader.dataset)
185187

@@ -212,12 +214,14 @@ def test(model, test_loader, epoch, logger=None):
212214
output = model(data)
213215
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
214216
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
215-
correct += pred.eq(target.data.view_as(pred)).cpu().sum()
217+
correct += pred.eq(target.data.view_as(pred)).sum().cpu()
216218

219+
test_loss = float(test_loss)
220+
correct = float(correct)
217221
test_loss /= len(test_loader.dataset)
218222
correct /= len(test_loader.dataset)
219223

220-
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(test_loss, correct))
224+
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(test_loss, 100. * correct))
221225

222226
logger.add_scalar('Test/loss', test_loss, epoch)
223227
logger.add_scalar('Test/acc', correct, epoch)

0 commit comments

Comments
 (0)