Skip to content

Commit 8d1c764

Browse files
author
Jaan Altosaar
committed
rename DATASETS_DIR to data_dir
1 parent d5f3f10 commit 8d1c764

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

data.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,37 @@
77
import h5py
88

99

10-
def parse_binary_mnist():
10+
def parse_binary_mnist(data_dir):
1111
def lines_to_np_array(lines):
1212
return np.array([[int(i) for i in line.split()] for line in lines])
13-
with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f:
13+
with open(os.path.join(data_dir, 'binarized_mnist_train.amat')) as f:
1414
lines = f.readlines()
1515
train_data = lines_to_np_array(lines).astype('float32')
16-
with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f:
16+
with open(os.path.join(data_dir, 'binarized_mnist_valid.amat')) as f:
1717
lines = f.readlines()
1818
validation_data = lines_to_np_array(lines).astype('float32')
19-
with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f:
19+
with open(os.path.join(data_dir, 'binarized_mnist_test.amat')) as f:
2020
lines = f.readlines()
2121
test_data = lines_to_np_array(lines).astype('float32')
2222
return train_data, validation_data, test_data
2323

2424

2525
def download_binary_mnist(fname):
26-
DATASETS_DIR = '/tmp/'
26+
data_dir = '/tmp/'
2727
subdatasets = ['train', 'valid', 'test']
2828
for subdataset in subdatasets:
2929
filename = 'binarized_mnist_{}.amat'.format(subdataset)
3030
url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(
3131
subdataset)
32-
local_filename = os.path.join(DATASETS_DIR, filename)
32+
local_filename = os.path.join(data_dir, filename)
3333
urllib.request.urlretrieve(url, local_filename)
3434

35-
train, validation, test = parse_binary_mnist()
35+
train, validation, test = parse_binary_mnist(data_dir)
3636

3737
data_dict = {'train': train, 'valid': validation, 'test': test}
3838
f = h5py.File(fname, 'w')
3939
f.create_dataset('train', data=data_dict['train'])
4040
f.create_dataset('valid', data=data_dict['valid'])
4141
f.create_dataset('test', data=data_dict['test'])
4242
f.close()
43+
print(f'Saved binary MNIST data to: {fname}')

train_variational_autoencoder_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def cycle(iterable):
152152
def load_binary_mnist(cfg, **kwcfg):
153153
fname = cfg.data_dir / 'binary_mnist.h5'
154154
if not fname.exists():
155+
print('Downloading binary MNIST data...')
155156
data.download_binary_mnist(fname)
156157
f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r')
157158
x_train = f['train'][::]

0 commit comments

Comments
 (0)