|
| 1 | +"""Get the binarized MNIST dataset and convert to hdf5. |
| 2 | +From https://github.com/yburda/iwae/blob/master/datasets.py |
| 3 | +""" |
| 4 | +import urllib.request |
| 5 | +import os |
| 6 | +import numpy as np |
| 7 | +import h5py |
| 8 | + |
| 9 | + |
| 10 | +def parse_binary_mnist(): |
| 11 | + def lines_to_np_array(lines): |
| 12 | + return np.array([[int(i) for i in line.split()] for line in lines]) |
| 13 | + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f: |
| 14 | + lines = f.readlines() |
| 15 | + train_data = lines_to_np_array(lines).astype('float32') |
| 16 | + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f: |
| 17 | + lines = f.readlines() |
| 18 | + validation_data = lines_to_np_array(lines).astype('float32') |
| 19 | + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f: |
| 20 | + lines = f.readlines() |
| 21 | + test_data = lines_to_np_array(lines).astype('float32') |
| 22 | + return train_data, validation_data, test_data |
| 23 | + |
| 24 | + |
| 25 | +def download_binary_mnist(fname): |
| 26 | + DATASETS_DIR = '/tmp/' |
| 27 | + subdatasets = ['train', 'valid', 'test'] |
| 28 | + for subdataset in subdatasets: |
| 29 | + filename = 'binarized_mnist_{}.amat'.format(subdataset) |
| 30 | + url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format( |
| 31 | + subdataset) |
| 32 | + local_filename = os.path.join(DATASETS_DIR, filename) |
| 33 | + urllib.request.urlretrieve(url, local_filename) |
| 34 | + |
| 35 | + train, validation, test = parse_binary_mnist() |
| 36 | + |
| 37 | + data_dict = {'train': train, 'valid': validation, 'test': test} |
| 38 | + f = h5py.File(fname, 'w') |
| 39 | + f.create_dataset('train', data=data_dict['train']) |
| 40 | + f.create_dataset('valid', data=data_dict['valid']) |
| 41 | + f.create_dataset('test', data=data_dict['test']) |
| 42 | + f.close() |
0 commit comments