|
| 1 | +""" |
| 2 | +Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | +SPDX-License-Identifier: MIT-0 |
| 4 | +""" |
| 5 | +import argparse |
| 6 | +import numpy as np |
| 7 | +import os |
| 8 | +import logging |
| 9 | +import time |
| 10 | +import tensorflow as tf |
| 11 | +import smdistributed.dataparallel |
| 12 | +import smdistributed.dataparallel.tensorflow as sdp |
| 13 | +import tensorflow.config.experimental as exp |
| 14 | +from tensorflow.data import Dataset |
| 15 | +from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Flatten, |
| 16 | + Dense, Dropout, BatchNormalization) |
| 17 | +from tensorflow.keras.models import Sequential |
| 18 | +from tensorflow.keras.optimizers import Adam |
| 19 | +from tensorflow.keras.callbacks import ModelCheckpoint |
| 20 | +from tensorflow.keras.losses import SparseCategoricalCrossentropy |
| 21 | +from tensorflow.keras.metrics import SparseCategoricalAccuracy |
| 22 | +from tensorflow.train import Checkpoint |
| 23 | + |
| 24 | +# Declare constants |
| 25 | +TRAIN_VERBOSE_LEVEL = 0 |
| 26 | +EVALUATE_VERBOSE_LEVEL = 0 |
| 27 | +IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS, NUM_CLASSES = 28, 28, 1, 10 |
| 28 | +VALIDATION_DATA_SPLIT = 0.1 |
| 29 | + |
| 30 | +# Create the logger |
| 31 | +logger = logging.getLogger(__name__) |
| 32 | +logger.setLevel(int(os.environ.get('SM_LOG_LEVEL', logging.INFO))) |
| 33 | + |
| 34 | + |
| 35 | +## Parse and load the command-line arguments sent to the script |
| 36 | +## These will be sent by SageMaker when it launches the training container |
| 37 | +def parse_args(): |
| 38 | + logger.info('Parsing command-line arguments...') |
| 39 | + parser = argparse.ArgumentParser() |
| 40 | + # Hyperparameters |
| 41 | + parser.add_argument('--epochs', type=int, default=1) |
| 42 | + parser.add_argument('--batch_size', type=int, default=64) |
| 43 | + parser.add_argument('--learning_rate', type=float, default=0.1) |
| 44 | + parser.add_argument('--decay', type=float, default=1e-6) |
| 45 | + # Data directories |
| 46 | + parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN')) |
| 47 | + parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_TEST')) |
| 48 | + # Model output directory |
| 49 | + parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR')) |
| 50 | + # Checkpoint info |
| 51 | + parser.add_argument('--checkpoint_enabled', type=str, default='True') |
| 52 | + parser.add_argument('--checkpoint_load_previous', type=str, default='True') |
| 53 | + parser.add_argument('--checkpoint_local_dir', type=str, default='/opt/ml/checkpoints/') |
| 54 | + logger.info('Completed parsing command-line arguments.') |
| 55 | + return parser.parse_known_args() |
| 56 | + |
| 57 | + |
| 58 | +## Initialize the SMDataParallel environment |
| 59 | +def init_sdp(): |
| 60 | + logger.info('Initializing the SMDataParallel environment...') |
| 61 | + tf.random.set_seed(42) |
| 62 | + sdp.init() |
| 63 | + logger.debug('Getting GPU list...') |
| 64 | + gpus = exp.list_physical_devices('GPU') |
| 65 | + logger.debug('Number of GPUs = {}'.format(len(gpus))) |
| 66 | + logger.debug('Completed getting GPU list.') |
| 67 | + logger.debug('Enabling memory growth on all GPUs...') |
| 68 | + for gpu in gpus: |
| 69 | + exp.set_memory_growth(gpu, True) |
| 70 | + logger.debug('Completed enabling memory growth on all GPUs.') |
| 71 | + logger.debug('Pinning GPUs to a single SMDataParallel process...') |
| 72 | + if gpus: |
| 73 | + exp.set_visible_devices(gpus[sdp.local_rank()], 'GPU') |
| 74 | + logger.debug('Completed pinning GPUs to a single SMDataParallel process.') |
| 75 | + logger.info('Completed initializing the SMDataParallel environment.') |
| 76 | + |
| 77 | + |
| 78 | +## Load data from local directory to memory and preprocess |
| 79 | +def load_and_preprocess_data(data_type, data_dir, x_data_file_name, y_data_file_name): |
| 80 | + logger.info('GPU # {} :: Loading and preprocessing {} data...'.format(sdp.rank(), data_type)) |
| 81 | + x_data = np.load(os.path.join(data_dir, x_data_file_name)) |
| 82 | + x_data = np.reshape(x_data, (x_data.shape[0], x_data.shape[1], x_data.shape[2], 1)) |
| 83 | + y_data = np.load(os.path.join(data_dir, y_data_file_name)) |
| 84 | + logger.info('GPU # {} :: Completed loading and preprocessing {} data.'.format(sdp.rank(), data_type)) |
| 85 | + return x_data, y_data |
| 86 | + |
| 87 | + |
| 88 | +## Construct the network |
| 89 | +def create_model(): |
| 90 | + logger.info('GPU # {} :: Creating the model...'.format(sdp.rank())) |
| 91 | + model = Sequential([ |
| 92 | + Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', |
| 93 | + input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS)), |
| 94 | + BatchNormalization(), |
| 95 | + Conv2D(64, kernel_size=(3, 3), activation='relu'), |
| 96 | + BatchNormalization(), |
| 97 | + MaxPooling2D(pool_size=(2, 2)), |
| 98 | + |
| 99 | + Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'), |
| 100 | + BatchNormalization(), |
| 101 | + Conv2D(128, kernel_size=(3, 3), activation='relu'), |
| 102 | + BatchNormalization(), |
| 103 | + MaxPooling2D(pool_size=(2, 2)), |
| 104 | + |
| 105 | + Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same'), |
| 106 | + BatchNormalization(), |
| 107 | + Conv2D(256, kernel_size=(3, 3), activation='relu'), |
| 108 | + BatchNormalization(), |
| 109 | + MaxPooling2D(pool_size=(2, 2)), |
| 110 | + |
| 111 | + Flatten(), |
| 112 | + |
| 113 | + Dense(1024, activation='relu'), |
| 114 | + |
| 115 | + Dense(512, activation='relu'), |
| 116 | + |
| 117 | + Dense(NUM_CLASSES, activation='softmax') |
| 118 | + ]) |
| 119 | + # Print the model summary |
| 120 | + logger.info(model.summary()) |
| 121 | + logger.info('GPU # {} :: Completed creating the model.'.format(sdp.rank())) |
| 122 | + return model |
| 123 | + |
| 124 | + |
| 125 | +## Load the weights from the latest checkpoint |
| 126 | +def load_weights_from_latest_checkpoint(model): |
| 127 | + file_list = os.listdir(args.checkpoint_local_dir) |
| 128 | + logger.info('GPU # {} :: Checking for checkpoint files...'.format(sdp.rank())) |
| 129 | + if len(file_list) > 0: |
| 130 | + logger.info('GPU # {} :: Checkpoint files found.'.format(sdp.rank())) |
| 131 | + logger.info('GPU # {} :: Loading the weights from the latest model checkpoint...'.format(sdp.rank())) |
| 132 | + model.load_weights(tf.train.latest_checkpoint(args.checkpoint_local_dir)) |
| 133 | + logger.info('GPU # {} :: Completed loading weights from the latest model checkpoint.'.format(sdp.rank())) |
| 134 | + else: |
| 135 | + logger.info('GPU # {} :: Checkpoint files not found.'.format(sdp.rank())) |
| 136 | + |
| 137 | + |
| 138 | +## Compile the model by setting the optimizer, loss function and metrics |
| 139 | +def compile_model(model, learning_rate, decay): |
| 140 | + logger.info('GPU # {} :: Compiling the model...'.format(sdp.rank())) |
| 141 | + # Instantiate the optimizer |
| 142 | + optimizer = Adam(learning_rate=learning_rate, decay=decay) |
| 143 | + # Instantiate the loss function |
| 144 | + loss_fn = SparseCategoricalCrossentropy(from_logits=True) |
| 145 | + # Prepare the metrics |
| 146 | + train_acc_metric = SparseCategoricalAccuracy() |
| 147 | + val_acc_metric = SparseCategoricalAccuracy() |
| 148 | + # Compile the model |
| 149 | + model.compile(optimizer=optimizer, |
| 150 | + loss=loss_fn, |
| 151 | + metrics=[train_acc_metric]) |
| 152 | + logger.info('GPU # {} :: Completed compiling the model.'.format(sdp.rank())) |
| 153 | + return optimizer, loss_fn, train_acc_metric, val_acc_metric |
| 154 | + |
| 155 | + |
| 156 | +## Prepare the batch datasets |
| 157 | +def prepare_batch_datasets(x_train, y_train, batch_size): |
| 158 | + logger.info('GPU # {} :: Preparing train and validation datasets for batches...'.format(sdp.rank())) |
| 159 | + # Reserve the required samples for validation |
| 160 | + x_val = x_train[-(len(x_train) * int(VALIDATION_DATA_SPLIT)):] |
| 161 | + y_val = y_train[-(len(y_train) * int(VALIDATION_DATA_SPLIT)):] |
| 162 | + # Prepare the training dataset with shuffling |
| 163 | + train_dataset = Dataset.from_tensor_slices((x_train, y_train)) |
| 164 | + train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size) |
| 165 | + # Prepare the validation dataset |
| 166 | + val_dataset = Dataset.from_tensor_slices((x_val, y_val)) |
| 167 | + val_dataset = val_dataset.batch(batch_size) |
| 168 | + logger.info('GPU # {} :: Completed preparing train and validation datasets for batches.'.format(sdp.rank())) |
| 169 | + return x_val, y_val, train_dataset, val_dataset |
| 170 | + |
| 171 | + |
| 172 | +## Define the training step |
| 173 | +@tf.function |
| 174 | +def training_step(model, x_batch_train, y_batch_train, optimizer, loss_fn, train_acc_metric, is_first_batch): |
| 175 | + # Open a GradientTape to record the operations run |
| 176 | + # during the forward pass, which enables auto-differentiation |
| 177 | + with tf.GradientTape() as tape: |
| 178 | + # Run the forward pass of the layer |
| 179 | + logits = model(x_batch_train, training=True) |
| 180 | + # Compute the loss value |
| 181 | + loss_value = loss_fn(y_batch_train, logits) |
| 182 | + # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape |
| 183 | + tape = sdp.DistributedGradientTape(tape) |
| 184 | + # Retrieve the gradients of the trainable variables with respect to the loss |
| 185 | + grads = tape.gradient(loss_value, model.trainable_weights) |
| 186 | + # Run one step of gradient descent by updating |
| 187 | + # the value of the variables to minimize the loss |
| 188 | + optimizer.apply_gradients(zip(grads, model.trainable_weights)) |
| 189 | + # Perform speicific SMDataParallel on the first batch |
| 190 | + if is_first_batch: |
| 191 | + # SMDataParallel: Broadcast model and optimizer variables |
| 192 | + sdp.broadcast_variables(model.variables, root_rank=0) |
| 193 | + sdp.broadcast_variables(optimizer.variables(), root_rank=0) |
| 194 | + # Update training metric |
| 195 | + train_acc_metric.update_state(y_batch_train, logits) |
| 196 | + # SMDataParallel: all_reduce call |
| 197 | + loss_value = sdp.oob_allreduce(loss_value) # Average the loss across workers |
| 198 | + return loss_value |
| 199 | + |
| 200 | + |
| 201 | +## Define the validation step |
| 202 | +@tf.function |
| 203 | +def validation_step(model, x_batch_val, y_batch_val, val_acc_metric): |
| 204 | + val_logits = model(x_batch_val, training=False) |
| 205 | + val_acc_metric.update_state(y_batch_val, val_logits) |
| 206 | + |
| 207 | + |
| 208 | +## Perform validation |
| 209 | +def perform_validation(model, val_dataset, val_acc_metric): |
| 210 | + logger.debug('GPU # {} :: Performing validation...'.format(sdp.rank())) |
| 211 | + for x_batch_val, y_batch_val in val_dataset: |
| 212 | + validation_step(model, x_batch_val, y_batch_val, val_acc_metric) |
| 213 | + logger.debug('GPU # {} :: Completed performing validation.'.format(sdp.rank())) |
| 214 | + return val_acc_metric.result() |
| 215 | + |
| 216 | +## Save the model as a checkpoint |
| 217 | +def save_checkpoint(checkpoint): |
| 218 | + logger.debug('GPU # {} :: Saving model checkpoint...'.format(sdp.rank())) |
| 219 | + checkpoint.save(os.path.join(args.checkpoint_local_dir, 'tf2-checkpoint')) |
| 220 | + logger.info('GPU # {} :: Checkpoint counter = {}'.format(sdp.rank(), checkpoint.save_counter.numpy())) |
| 221 | + logger.debug('GPU # {} :: Completed saving model checkpoint.'.format(sdp.rank())) |
| 222 | + |
| 223 | + |
| 224 | +## Train the model |
| 225 | +def train_model(model, model_dir, x_train, y_train, batch_size, epochs, learning_rate, decay): |
| 226 | + history = [] |
| 227 | + |
| 228 | + # SMDataParallel: Scale learning rate |
| 229 | + learning_rate = learning_rate * sdp.size() |
| 230 | + |
| 231 | + # Compile the model |
| 232 | + optimizer, loss_fn, train_acc_metric, val_acc_metric = compile_model(model, learning_rate, decay) |
| 233 | + |
| 234 | + # SMDataParallel: Initialize to perform checkpointing only from leader node |
| 235 | + if sdp.rank() == 0: |
| 236 | + if args.checkpoint_enabled.lower() == 'true': |
| 237 | + # Create the checkpoint object |
| 238 | + checkpoint = Checkpoint(model) |
| 239 | + |
| 240 | + # Prepare the batch datasets |
| 241 | + x_val, y_val, train_dataset, val_dataset = prepare_batch_datasets(x_train, y_train, batch_size) |
| 242 | + |
| 243 | + # Perform training |
| 244 | + logger.info('GPU # {} :: Training the model...'.format(sdp.rank())) |
| 245 | + training_start_time = time.time() |
| 246 | + logger.debug('GPU # {} :: Iterating over epochs...'.format(sdp.rank())) |
| 247 | + # Iterate over epochs |
| 248 | + for epoch in range(epochs): |
| 249 | + logger.debug('Starting epoch {}...'.format(sdp.rank(), int(epoch) + 1)) |
| 250 | + epoch_start_time = time.time() |
| 251 | + |
| 252 | + # Iterate over the batches of the dataset |
| 253 | + for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): |
| 254 | + logger.debug('GPU # {} :: Running training step {}...'.format(sdp.rank(), int(step) + 1)) |
| 255 | + loss_value = training_step(model, x_batch_train, y_batch_train, optimizer, loss_fn, |
| 256 | + train_acc_metric, step == 0) |
| 257 | + logger.debug('GPU # {} :: Training loss in step = {}'.format(sdp.rank(), loss_value)) |
| 258 | + logger.debug('GPU # {} :: Completed running training step {}.'.format(sdp.rank(), int(step) + 1)) |
| 259 | + |
| 260 | + # SMDataParallel: Perform validation only from leader node |
| 261 | + if sdp.rank() == 0: |
| 262 | + # Perform validation and save metrics at the end of each epoch |
| 263 | + history.append([int(epoch) + 1, train_acc_metric.result(), |
| 264 | + perform_validation(model, val_dataset, val_acc_metric)]) |
| 265 | + |
| 266 | + # Reset metrics |
| 267 | + train_acc_metric.reset_states() |
| 268 | + val_acc_metric.reset_states() |
| 269 | + |
| 270 | + # SMDataParallel: Perform model checkpointing only from leader node |
| 271 | + if sdp.rank() == 0: |
| 272 | + if args.checkpoint_enabled.lower() == 'true': |
| 273 | + # Save the model as a checkpoint |
| 274 | + save_checkpoint(checkpoint) |
| 275 | + |
| 276 | + epoch_end_time = time.time() |
| 277 | + # SMDataParallel: Print epoch time only from leader node |
| 278 | + if sdp.rank() == 0: |
| 279 | + logger.debug('Epoch duration (primary node) = %.2f second(s)' % (epoch_end_time - epoch_start_time)) |
| 280 | + logger.debug('GPU # {} :: Completed epoch {}.'.format(sdp.rank(), int(epoch) + 1)) |
| 281 | + |
| 282 | + logger.debug('GPU # {} :: Completed iterating over epochs.'.format(sdp.rank())) |
| 283 | + training_end_time = time.time() |
| 284 | + # SMDataParallel: Print training time and result only from leader node |
| 285 | + if sdp.rank() == 0: |
| 286 | + logger.info('Training duration (primary node) = %.2f second(s)' % (training_end_time - training_start_time)) |
| 287 | + print_training_result(history) |
| 288 | + logger.info('GPU # {} :: Completed training the model.'.format(sdp.rank())) |
| 289 | + |
| 290 | + |
| 291 | +## Print training result |
| 292 | +def print_training_result(history): |
| 293 | + output_table_string_list = [] |
| 294 | + output_table_string_list.append('\n') |
| 295 | + output_table_string_list.append("{:<10} {:<25} {:<25}".format('Epoch', 'Accuracy', 'Validation Accuracy')) |
| 296 | + output_table_string_list.append('\n') |
| 297 | + size = len(history) |
| 298 | + for index in range(size): |
| 299 | + record = history[index] |
| 300 | + output_table_string_list.append("{:<10} {:<25} {:<25}".format(record[0], record[1], record[2])) |
| 301 | + output_table_string_list.append('\n') |
| 302 | + output_table_string_list.append('\n') |
| 303 | + logger.info(''.join(output_table_string_list)) |
| 304 | + |
| 305 | + |
| 306 | +## Evaluate the model |
| 307 | +def evaluate_model(model, x_test, y_test): |
| 308 | + logger.info('GPU # {} :: Evaluating the model...'.format(sdp.rank())) |
| 309 | + test_loss, test_accuracy = model.evaluate(x_test, y_test, |
| 310 | + verbose=EVALUATE_VERBOSE_LEVEL) |
| 311 | + logger.info('GPU # {} :: Test loss = {}'.format(sdp.rank(), test_loss)) |
| 312 | + logger.info('GPU # {} :: Test accuracy = {}'.format(sdp.rank(), test_accuracy)) |
| 313 | + logger.info('GPU # {} :: Completed evaluating the model.'.format(sdp.rank())) |
| 314 | + return test_loss, test_accuracy |
| 315 | + |
| 316 | + |
| 317 | +## Save the model |
| 318 | +def save_model(model, model_dir): |
| 319 | + logger.info('GPU # {} :: Saving the model...'.format(sdp.rank())) |
| 320 | + tf.saved_model.save(model, model_dir) |
| 321 | + logger.info('GPU # {} :: Completed saving the model.'.format(sdp.rank())) |
| 322 | + |
| 323 | + |
| 324 | +## The main function |
| 325 | +if __name__ == "__main__": |
| 326 | + logger.info('Executing the main() function...') |
| 327 | + # Parse command-line arguments |
| 328 | + args, _ = parse_args() |
| 329 | + # Initialize the SMDataParallel environment |
| 330 | + init_sdp() |
| 331 | + # Log version info |
| 332 | + logger.info('GPU # {} :: TensorFlow version : {}'.format(sdp.rank(), tf.__version__)) |
| 333 | + logger.info('GPU # {} :: SMDistributedDataParallel version : {}'.format(sdp.rank(), smdistributed.dataparallel.__version__)) |
| 334 | + # Load train and test data |
| 335 | + x_train, y_train = load_and_preprocess_data('training', args.train, 'x_train.npy', 'y_train.npy') |
| 336 | + x_test, y_test = load_and_preprocess_data('test', args.test, 'x_test.npy', 'y_test.npy') |
| 337 | + # Create, train and evaluate the model |
| 338 | + model = create_model() |
| 339 | + if args.checkpoint_load_previous.lower() == 'true': |
| 340 | + load_weights_from_latest_checkpoint(model) |
| 341 | + train_model(model, args.model_dir, x_train, y_train, args.batch_size, args.epochs, args.learning_rate, args.decay) |
| 342 | + # SMDataParallel: Evaluate and save model only from leader node |
| 343 | + if sdp.rank() == 0: |
| 344 | + # Evaluate the generated model |
| 345 | + evaluate_model(model, x_test, y_test) |
| 346 | + # Save the generated model |
| 347 | + save_model(model, args.model_dir) |
| 348 | + logger.info('Completed executing the main() function.') |
0 commit comments