|
7 | 7 | import h5py |
8 | 8 |
|
9 | 9 |
|
10 | | -def parse_binary_mnist(): |
| 10 | +def parse_binary_mnist(data_dir): |
11 | 11 | def lines_to_np_array(lines): |
12 | 12 | return np.array([[int(i) for i in line.split()] for line in lines]) |
13 | | - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f: |
| 13 | + with open(os.path.join(data_dir, 'binarized_mnist_train.amat')) as f: |
14 | 14 | lines = f.readlines() |
15 | 15 | train_data = lines_to_np_array(lines).astype('float32') |
16 | | - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f: |
| 16 | + with open(os.path.join(data_dir, 'binarized_mnist_valid.amat')) as f: |
17 | 17 | lines = f.readlines() |
18 | 18 | validation_data = lines_to_np_array(lines).astype('float32') |
19 | | - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f: |
| 19 | + with open(os.path.join(data_dir, 'binarized_mnist_test.amat')) as f: |
20 | 20 | lines = f.readlines() |
21 | 21 | test_data = lines_to_np_array(lines).astype('float32') |
22 | 22 | return train_data, validation_data, test_data |
23 | 23 |
|
24 | 24 |
|
25 | 25 | def download_binary_mnist(fname): |
26 | | - DATASETS_DIR = '/tmp/' |
| 26 | + data_dir = '/tmp/' |
27 | 27 | subdatasets = ['train', 'valid', 'test'] |
28 | 28 | for subdataset in subdatasets: |
29 | 29 | filename = 'binarized_mnist_{}.amat'.format(subdataset) |
30 | 30 | url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format( |
31 | 31 | subdataset) |
32 | | - local_filename = os.path.join(DATASETS_DIR, filename) |
| 32 | + local_filename = os.path.join(data_dir, filename) |
33 | 33 | urllib.request.urlretrieve(url, local_filename) |
34 | 34 |
|
35 | | - train, validation, test = parse_binary_mnist() |
| 35 | + train, validation, test = parse_binary_mnist(data_dir) |
36 | 36 |
|
37 | 37 | data_dict = {'train': train, 'valid': validation, 'test': test} |
38 | 38 | f = h5py.File(fname, 'w') |
39 | 39 | f.create_dataset('train', data=data_dict['train']) |
40 | 40 | f.create_dataset('valid', data=data_dict['valid']) |
41 | 41 | f.create_dataset('test', data=data_dict['test']) |
42 | 42 | f.close() |
| 43 | + print(f'Saved binary MNIST data to: {fname}') |
0 commit comments