Skip to content

Commit d5f3f10

Browse files
author
Jaan Altosaar
committed
Merge branch 'master' of github.com:altosaar/vae
2 parents cdf9000 + 517c52b commit d5f3f10

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

train_variational_autoencoder_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
test_batch_size: 512
2929
max_iterations: 100000
3030
log_interval: 10000
31+
early_stopping_interval: 5
3132
n_samples: 128
3233
use_gpu: true
3334
train_dir: $TMPDIR
@@ -188,6 +189,7 @@ def evaluate(n_samples, model, variational, eval_data):
188189
if __name__ == '__main__':
189190
dictionary = yaml.load(config)
190191
cfg = nomen.Config(dictionary)
192+
cfg.parse_args()
191193
device = torch.device("cuda:0" if cfg.use_gpu else "cpu")
192194
torch.manual_seed(cfg.seed)
193195
np.random.seed(cfg.seed)
@@ -246,7 +248,7 @@ def evaluate(n_samples, model, variational, eval_data):
246248
else:
247249
num_no_improvement += 1
248250

249-
if num_no_improvement > 5:
251+
if num_no_improvement > cfg.early_stopping_interval:
250252
checkpoint = torch.load(cfg.train_dir / 'best_state_dict')
251253
model.load_state_dict(checkpoint['model'])
252254
variational.load_state_dict(checkpoint['variational'])

0 commit comments

Comments
 (0)