
Commit 1ee8fba

Add files via upload
1. Renamed to indicate 'sdp' for SageMaker Distributed Data Parallel. 2. Added support for both scenarios of training scripts for SageMaker Debugger.
1 parent 06130bf commit 1ee8fba
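
Note (not part of this commit): below is a minimal sketch of how a script like this one is typically launched with SMDataParallel enabled through the SageMaker Python SDK. The entry-point file name, IAM role, S3 paths, instance settings, version strings, and the Debugger rule are illustrative assumptions, not values taken from this repository.

# Illustrative launcher sketch; all names and paths are placeholders
from sagemaker.tensorflow import TensorFlow
from sagemaker.debugger import Rule, rule_configs

estimator = TensorFlow(
    entry_point='tf2_fashion_mnist_sdp.py',   # hypothetical name of the renamed 'sdp' script
    role='<sagemaker-execution-role-arn>',    # placeholder IAM role
    instance_count=2,
    instance_type='ml.p3.16xlarge',           # SMDataParallel runs on multi-GPU instance types
    framework_version='2.4.1',
    py_version='py37',
    hyperparameters={'epochs': 5, 'batch_size': 64,
                     'learning_rate': 0.1, 'decay': 1e-6},
    # Enables SageMaker distributed data parallel; the script then calls sdp.init()
    distribution={'smdistributed': {'dataparallel': {'enabled': True}}},
    # Mirrors /opt/ml/checkpoints/ (the script's --checkpoint_local_dir default) to S3
    checkpoint_s3_uri='s3://<bucket>/checkpoints/',
    # A built-in SageMaker Debugger rule, usable without modifying the training script
    rules=[Rule.sagemaker(rule_configs.loss_not_decreasing())],
)

# Channel names map to SM_CHANNEL_TRAIN and SM_CHANNEL_TEST inside the container
estimator.fit({'train': 's3://<bucket>/data/train/',
               'test': 's3://<bucket>/data/test/'})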

File tree

3 files changed: +1914 -0 lines changed

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
"""
Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""
import argparse
import numpy as np
import os
import logging
import time
import tensorflow as tf
import smdistributed.dataparallel
import smdistributed.dataparallel.tensorflow as sdp
import tensorflow.config.experimental as exp
from tensorflow.data import Dataset
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Flatten,
                                     Dense, Dropout, BatchNormalization)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.train import Checkpoint

# Declare constants
TRAIN_VERBOSE_LEVEL = 0
EVALUATE_VERBOSE_LEVEL = 0
IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS, NUM_CLASSES = 28, 28, 1, 10
VALIDATION_DATA_SPLIT = 0.1

# Create the logger
logger = logging.getLogger(__name__)
logger.setLevel(int(os.environ.get('SM_LOG_LEVEL', logging.INFO)))

## Parse and load the command-line arguments sent to the script
## These will be sent by SageMaker when it launches the training container
def parse_args():
    logger.info('Parsing command-line arguments...')
    parser = argparse.ArgumentParser()
    # Hyperparameters
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--decay', type=float, default=1e-6)
    # Data directories
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_TEST'))
    # Model output directory
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    # Checkpoint info
    parser.add_argument('--checkpoint_enabled', type=str, default='True')
    parser.add_argument('--checkpoint_load_previous', type=str, default='True')
    parser.add_argument('--checkpoint_local_dir', type=str, default='/opt/ml/checkpoints/')
    logger.info('Completed parsing command-line arguments.')
    return parser.parse_known_args()

## Initialize the SMDataParallel environment
def init_sdp():
    logger.info('Initializing the SMDataParallel environment...')
    tf.random.set_seed(42)
    sdp.init()
    logger.debug('Getting GPU list...')
    gpus = exp.list_physical_devices('GPU')
    logger.debug('Number of GPUs = {}'.format(len(gpus)))
    logger.debug('Completed getting GPU list.')
    logger.debug('Enabling memory growth on all GPUs...')
    for gpu in gpus:
        exp.set_memory_growth(gpu, True)
    logger.debug('Completed enabling memory growth on all GPUs.')
    logger.debug('Pinning GPUs to a single SMDataParallel process...')
    if gpus:
        exp.set_visible_devices(gpus[sdp.local_rank()], 'GPU')
    logger.debug('Completed pinning GPUs to a single SMDataParallel process.')
    logger.info('Completed initializing the SMDataParallel environment.')

## Load data from local directory to memory and preprocess
def load_and_preprocess_data(data_type, data_dir, x_data_file_name, y_data_file_name):
    logger.info('GPU # {} :: Loading and preprocessing {} data...'.format(sdp.rank(), data_type))
    x_data = np.load(os.path.join(data_dir, x_data_file_name))
    x_data = np.reshape(x_data, (x_data.shape[0], x_data.shape[1], x_data.shape[2], 1))
    y_data = np.load(os.path.join(data_dir, y_data_file_name))
    logger.info('GPU # {} :: Completed loading and preprocessing {} data.'.format(sdp.rank(), data_type))
    return x_data, y_data

## Construct the network
def create_model():
    logger.info('GPU # {} :: Creating the model...'.format(sdp.rank()))
    model = Sequential([
        Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same',
               input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS)),
        BatchNormalization(),
        Conv2D(64, kernel_size=(3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        Conv2D(128, kernel_size=(3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        Conv2D(256, kernel_size=(3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),

        Flatten(),

        Dense(1024, activation='relu'),

        Dense(512, activation='relu'),

        Dense(NUM_CLASSES, activation='softmax')
    ])
    # Log the model summary (model.summary() prints and returns None, so pass the logger as print_fn)
    model.summary(print_fn=logger.info)
    logger.info('GPU # {} :: Completed creating the model.'.format(sdp.rank()))
    return model

## Load the weights from the latest checkpoint
def load_weights_from_latest_checkpoint(model):
    file_list = os.listdir(args.checkpoint_local_dir)
    logger.info('GPU # {} :: Checking for checkpoint files...'.format(sdp.rank()))
    if len(file_list) > 0:
        logger.info('GPU # {} :: Checkpoint files found.'.format(sdp.rank()))
        logger.info('GPU # {} :: Loading the weights from the latest model checkpoint...'.format(sdp.rank()))
        model.load_weights(tf.train.latest_checkpoint(args.checkpoint_local_dir))
        logger.info('GPU # {} :: Completed loading weights from the latest model checkpoint.'.format(sdp.rank()))
    else:
        logger.info('GPU # {} :: Checkpoint files not found.'.format(sdp.rank()))

## Compile the model by setting the optimizer, loss function and metrics
def compile_model(model, learning_rate, decay):
    logger.info('GPU # {} :: Compiling the model...'.format(sdp.rank()))
    # Instantiate the optimizer
    optimizer = Adam(learning_rate=learning_rate, decay=decay)
    # Instantiate the loss function; the model's final layer applies softmax,
    # so the outputs are probabilities rather than logits
    loss_fn = SparseCategoricalCrossentropy(from_logits=False)
    # Prepare the metrics
    train_acc_metric = SparseCategoricalAccuracy()
    val_acc_metric = SparseCategoricalAccuracy()
    # Compile the model
    model.compile(optimizer=optimizer,
                  loss=loss_fn,
                  metrics=[train_acc_metric])
    logger.info('GPU # {} :: Completed compiling the model.'.format(sdp.rank()))
    return optimizer, loss_fn, train_acc_metric, val_acc_metric

## Prepare the batch datasets
def prepare_batch_datasets(x_train, y_train, batch_size):
    logger.info('GPU # {} :: Preparing train and validation datasets for batches...'.format(sdp.rank()))
    # Reserve the required samples for validation
    # (the int() must wrap the product; int(VALIDATION_DATA_SPLIT) would truncate to 0)
    x_val = x_train[-int(len(x_train) * VALIDATION_DATA_SPLIT):]
    y_val = y_train[-int(len(y_train) * VALIDATION_DATA_SPLIT):]
    # Prepare the training dataset with shuffling
    train_dataset = Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
    # Prepare the validation dataset
    val_dataset = Dataset.from_tensor_slices((x_val, y_val))
    val_dataset = val_dataset.batch(batch_size)
    logger.info('GPU # {} :: Completed preparing train and validation datasets for batches.'.format(sdp.rank()))
    return x_val, y_val, train_dataset, val_dataset

## Define the training step
@tf.function
def training_step(model, x_batch_train, y_batch_train, optimizer, loss_fn, train_acc_metric, is_first_batch):
    # Open a GradientTape to record the operations run
    # during the forward pass, which enables auto-differentiation
    with tf.GradientTape() as tape:
        # Run the forward pass of the layer
        logits = model(x_batch_train, training=True)
        # Compute the loss value
        loss_value = loss_fn(y_batch_train, logits)
    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = sdp.DistributedGradientTape(tape)
    # Retrieve the gradients of the trainable variables with respect to the loss
    grads = tape.gradient(loss_value, model.trainable_weights)
    # Run one step of gradient descent by updating
    # the value of the variables to minimize the loss
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    # Perform SMDataParallel-specific initialization on the first batch
    if is_first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        sdp.broadcast_variables(model.variables, root_rank=0)
        sdp.broadcast_variables(optimizer.variables(), root_rank=0)
    # Update training metric
    train_acc_metric.update_state(y_batch_train, logits)
    # SMDataParallel: all_reduce call
    loss_value = sdp.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value

## Define the validation step
@tf.function
def validation_step(model, x_batch_val, y_batch_val, val_acc_metric):
    val_logits = model(x_batch_val, training=False)
    val_acc_metric.update_state(y_batch_val, val_logits)


## Perform validation
def perform_validation(model, val_dataset, val_acc_metric):
    logger.debug('GPU # {} :: Performing validation...'.format(sdp.rank()))
    for x_batch_val, y_batch_val in val_dataset:
        validation_step(model, x_batch_val, y_batch_val, val_acc_metric)
    logger.debug('GPU # {} :: Completed performing validation.'.format(sdp.rank()))
    return val_acc_metric.result()

## Save the model as a checkpoint
def save_checkpoint(checkpoint):
    logger.debug('GPU # {} :: Saving model checkpoint...'.format(sdp.rank()))
    checkpoint.save(os.path.join(args.checkpoint_local_dir, 'tf2-checkpoint'))
    logger.info('GPU # {} :: Checkpoint counter = {}'.format(sdp.rank(), checkpoint.save_counter.numpy()))
    logger.debug('GPU # {} :: Completed saving model checkpoint.'.format(sdp.rank()))

## Train the model
def train_model(model, model_dir, x_train, y_train, batch_size, epochs, learning_rate, decay):
    history = []

    # SMDataParallel: Scale learning rate
    learning_rate = learning_rate * sdp.size()

    # Compile the model
    optimizer, loss_fn, train_acc_metric, val_acc_metric = compile_model(model, learning_rate, decay)

    # SMDataParallel: Initialize to perform checkpointing only from leader node
    if sdp.rank() == 0:
        if args.checkpoint_enabled.lower() == 'true':
            # Create the checkpoint object (tf.train.Checkpoint takes keyword arguments only)
            checkpoint = Checkpoint(model=model)

    # Prepare the batch datasets
    x_val, y_val, train_dataset, val_dataset = prepare_batch_datasets(x_train, y_train, batch_size)

    # Perform training
    logger.info('GPU # {} :: Training the model...'.format(sdp.rank()))
    training_start_time = time.time()
    logger.debug('GPU # {} :: Iterating over epochs...'.format(sdp.rank()))
    # Iterate over epochs
    for epoch in range(epochs):
        logger.debug('GPU # {} :: Starting epoch {}...'.format(sdp.rank(), int(epoch) + 1))
        epoch_start_time = time.time()

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            logger.debug('GPU # {} :: Running training step {}...'.format(sdp.rank(), int(step) + 1))
            loss_value = training_step(model, x_batch_train, y_batch_train, optimizer, loss_fn,
                                       train_acc_metric, step == 0)
            logger.debug('GPU # {} :: Training loss in step = {}'.format(sdp.rank(), loss_value))
            logger.debug('GPU # {} :: Completed running training step {}.'.format(sdp.rank(), int(step) + 1))

        # SMDataParallel: Perform validation only from leader node
        if sdp.rank() == 0:
            # Perform validation and save metrics at the end of each epoch
            # (stored as plain floats so they format cleanly when printed later)
            history.append([int(epoch) + 1, float(train_acc_metric.result()),
                            float(perform_validation(model, val_dataset, val_acc_metric))])

        # Reset metrics
        train_acc_metric.reset_states()
        val_acc_metric.reset_states()

        # SMDataParallel: Perform model checkpointing only from leader node
        if sdp.rank() == 0:
            if args.checkpoint_enabled.lower() == 'true':
                # Save the model as a checkpoint
                save_checkpoint(checkpoint)

        epoch_end_time = time.time()
        # SMDataParallel: Print epoch time only from leader node
        if sdp.rank() == 0:
            logger.debug('Epoch duration (primary node) = %.2f second(s)' % (epoch_end_time - epoch_start_time))
        logger.debug('GPU # {} :: Completed epoch {}.'.format(sdp.rank(), int(epoch) + 1))

    logger.debug('GPU # {} :: Completed iterating over epochs.'.format(sdp.rank()))
    training_end_time = time.time()
    # SMDataParallel: Print training time and result only from leader node
    if sdp.rank() == 0:
        logger.info('Training duration (primary node) = %.2f second(s)' % (training_end_time - training_start_time))
        print_training_result(history)
    logger.info('GPU # {} :: Completed training the model.'.format(sdp.rank()))

## Print training result
def print_training_result(history):
    output_table_string_list = []
    output_table_string_list.append('\n')
    output_table_string_list.append("{:<10} {:<25} {:<25}".format('Epoch', 'Accuracy', 'Validation Accuracy'))
    output_table_string_list.append('\n')
    for record in history:
        output_table_string_list.append("{:<10} {:<25} {:<25}".format(record[0], record[1], record[2]))
        output_table_string_list.append('\n')
    output_table_string_list.append('\n')
    logger.info(''.join(output_table_string_list))

## Evaluate the model
def evaluate_model(model, x_test, y_test):
    logger.info('GPU # {} :: Evaluating the model...'.format(sdp.rank()))
    test_loss, test_accuracy = model.evaluate(x_test, y_test,
                                              verbose=EVALUATE_VERBOSE_LEVEL)
    logger.info('GPU # {} :: Test loss = {}'.format(sdp.rank(), test_loss))
    logger.info('GPU # {} :: Test accuracy = {}'.format(sdp.rank(), test_accuracy))
    logger.info('GPU # {} :: Completed evaluating the model.'.format(sdp.rank()))
    return test_loss, test_accuracy

## Save the model
def save_model(model, model_dir):
    logger.info('GPU # {} :: Saving the model...'.format(sdp.rank()))
    tf.saved_model.save(model, model_dir)
    logger.info('GPU # {} :: Completed saving the model.'.format(sdp.rank()))

## The main function
if __name__ == "__main__":
    logger.info('Executing the main() function...')
    # Parse command-line arguments
    args, _ = parse_args()
    # Initialize the SMDataParallel environment
    init_sdp()
    # Log version info
    logger.info('GPU # {} :: TensorFlow version : {}'.format(sdp.rank(), tf.__version__))
    logger.info('GPU # {} :: SMDistributedDataParallel version : {}'.format(sdp.rank(), smdistributed.dataparallel.__version__))
    # Load train and test data
    x_train, y_train = load_and_preprocess_data('training', args.train, 'x_train.npy', 'y_train.npy')
    x_test, y_test = load_and_preprocess_data('test', args.test, 'x_test.npy', 'y_test.npy')
    # Create, train and evaluate the model
    model = create_model()
    if args.checkpoint_load_previous.lower() == 'true':
        load_weights_from_latest_checkpoint(model)
    train_model(model, args.model_dir, x_train, y_train, args.batch_size, args.epochs, args.learning_rate, args.decay)
    # SMDataParallel: Evaluate and save model only from leader node
    if sdp.rank() == 0:
        # Evaluate the generated model
        evaluate_model(model, x_test, y_test)
        # Save the generated model
        save_model(model, args.model_dir)
    logger.info('Completed executing the main() function.')
