Skip to content

Commit cfc456a

Browse files
committed
Updated data location
1 parent b239ce7 commit cfc456a

File tree

2 files changed

+122
-6
lines changed

2 files changed

+122
-6
lines changed

0_Running_TensorFlow_In_SageMaker.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
"metadata": {},
3737
"source": [
3838
"## Getting the data\n",
39-
"Copy the cifar10 tfrecord datasets from s3://floor28/data/cifar10 to your local notebook\n",
39+
"To use the CIFAR-10 dataset, we first need to download it and convert it to TFRecords. This step takes around 5 minutes.\n",
4040
"\n",
41-
"You can use the following AWS CLI command:"
41+
"You can use the following command:"
4242
]
4343
},
4444
{
@@ -47,7 +47,7 @@
4747
"metadata": {},
4848
"outputs": [],
4949
"source": [
50-
"!aws s3 cp --recursive s3://floor28/data/cifar10 ./data"
50+
"!python generate_cifar10_tfrecords.py --data-dir ./data"
5151
]
5252
},
5353
{
@@ -422,13 +422,13 @@
422422
"pycharm": {
423423
"stem_cell": {
424424
"cell_type": "raw",
425-
"source": [],
426425
"metadata": {
427426
"collapsed": false
428-
}
427+
},
428+
"source": []
429429
}
430430
}
431431
},
432432
"nbformat": 4,
433433
"nbformat_minor": 4
434-
}
434+
}

generate_cifar10_tfrecords.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# https://aws.amazon.com/apache-2-0/
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import argparse
16+
import os
17+
import shutil
18+
import sys
19+
import tarfile
20+
21+
import tensorflow as tf
22+
from six.moves import cPickle as pickle
23+
from six.moves import xrange
24+
25+
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
26+
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
27+
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
28+
29+
30+
def download_and_extract(data_dir):
31+
# download CIFAR-10 if not already downloaded.
32+
tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, CIFAR_DOWNLOAD_URL)
33+
tarfile.open(os.path.join(data_dir, CIFAR_FILENAME), 'r:gz').extractall(data_dir)
34+
35+
36+
def _int64_feature(value):
37+
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
38+
39+
40+
def _bytes_feature(value):
41+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
42+
43+
44+
def _get_file_names():
45+
"""Returns the file names expected to exist in the input_dir."""
46+
return {
47+
'train': ['data_batch_%d' % i for i in xrange(1, 5)],
48+
'validation': ['data_batch_5'],
49+
'eval': ['test_batch'],
50+
}
51+
52+
53+
def read_pickle_from_file(filename):
54+
with tf.io.gfile.GFile(filename, 'rb') as f:
55+
if sys.version_info.major >= 3:
56+
return pickle.load(f, encoding='bytes')
57+
else:
58+
return pickle.load(f)
59+
60+
61+
def convert_to_tfrecord(input_files, output_file):
62+
"""Converts a file to TFRecords."""
63+
print('Generating %s' % output_file)
64+
with tf.io.TFRecordWriter(output_file) as record_writer:
65+
for input_file in input_files:
66+
data_dict = read_pickle_from_file(input_file)
67+
data = data_dict[b'data']
68+
labels = data_dict[b'labels']
69+
70+
num_entries_in_batch = len(labels)
71+
for i in range(num_entries_in_batch):
72+
example = tf.train.Example(features=tf.train.Features(
73+
feature={
74+
'image': _bytes_feature(data[i].tobytes()),
75+
'label': _int64_feature(labels[i])
76+
}))
77+
record_writer.write(example.SerializeToString())
78+
79+
80+
def main(data_dir):
81+
print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
82+
download_and_extract(data_dir)
83+
84+
file_names = _get_file_names()
85+
input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
86+
for mode, files in file_names.items():
87+
input_files = [os.path.join(input_dir, f) for f in files]
88+
89+
mode_dir = os.path.join(data_dir, mode)
90+
output_file = os.path.join(mode_dir, mode + '.tfrecords')
91+
if not os.path.exists(mode_dir):
92+
os.makedirs(mode_dir)
93+
try:
94+
os.remove(output_file)
95+
except OSError:
96+
pass
97+
98+
# Convert to tf.train.Example and write the to TFRecords.
99+
convert_to_tfrecord(input_files, output_file)
100+
101+
print('Done!')
102+
shutil.rmtree(os.path.join(data_dir, 'cifar-10-batches-py'))
103+
os.remove(os.path.join(data_dir, 'cifar-10-python.tar.gz')) # Remove the original .tzr.gz files
104+
105+
106+
if __name__ == '__main__':
107+
parser = argparse.ArgumentParser()
108+
parser.add_argument(
109+
'--data-dir',
110+
type=str,
111+
default='',
112+
help='Directory to download and extract CIFAR-10 to.'
113+
)
114+
115+
args = parser.parse_args()
116+
main(args.data_dir)

0 commit comments

Comments
 (0)