From 6caed081eaec49f2ff9f0f7d25a506b34ef7f2a4 Mon Sep 17 00:00:00 2001 From: firestonelib Date: Sat, 17 Sep 2022 10:20:45 +0000 Subject: [PATCH 01/10] add super res 512 and 1024 --- .../imagen/imagen_super_resolusion_1024.yaml | 56 +++++++++++++++++++ .../imagen/imagen_super_resolusion_256.yaml | 56 +++++++++++++++++++ .../imagen/imagen_super_resolusion_512.yaml | 56 +++++++++++++++++++ ppfleetx/core/engine/eager_engine.py | 2 +- .../multimodal_model/imagen/__init__.py | 4 +- .../multimodal_model/imagen/modeling.py | 8 +++ .../multimodal_model/multimodal_module.py | 3 +- .../run_super_resolusion_1024_single.sh | 18 ++++++ .../imagen/run_super_resolusion_512_single.sh | 18 ++++++ .../imagen/run_text2im_397M_64x64_single.sh | 2 +- 10 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml create mode 100644 ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_256.yaml create mode 100644 ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml create mode 100644 projects/imagen/run_super_resolusion_1024_single.sh create mode 100644 projects/imagen/run_super_resolusion_512_single.sh diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml new file mode 100644 index 000000000..658d03ad4 --- /dev/null +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -0,0 +1,56 @@ +_base_: ./imagen_base.yaml + +Global: + global_batch_size: + local_batch_size: 1 + micro_batch_size: 1 + + +Model: + name: imagen_SR1024 + text_encoder_name: t5/t5-11b + text_embed_dim: 1024 + timesteps: 1000 + in_chans: 3 + cond_drop_prob: 0.1 + noise_schedules: cosine + pred_objectives: noise + lowres_noise_schedule: linear + lowres_sample_noise_level: 0.2 + per_sample_random_aug_noise_level: False + condition_on_text: True + auto_normalize_img: True + p2_loss_weight_gamma: 0.5 + dynamic_thresholding: True, + dynamic_thresholding_percentile: 0.95 + only_train_unet_number: 1 + use_recompute: False + +Data: + Train: + dataset: + name: ImagenDataset + input_path: ./data/cc12m_base64.lst + shuffle: True + input_resolusion: 1024 + max_seq_len: 128 + loader: + num_workers: 8 + shuffle: True + batch_size: 1 + drop_last: True + collate_fn: imagen_collate_fn + + +Loss: + name: mse_loss + p2_loss_weight_k: 1.0 + +Distributed: + dp_degree: 1 + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 + sharding_offload: False diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_256.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_256.yaml new file mode 100644 index 000000000..5929527ed --- /dev/null +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_256.yaml @@ -0,0 +1,56 @@ +_base_: ./imagen_base.yaml + +Global: + global_batch_size: + local_batch_size: 1 + micro_batch_size: 1 + + +Model: + name: imagen_SR256 + text_encoder_name: t5/t5-11b + text_embed_dim: 1024 + timesteps: 1000 + in_chans: 3 + cond_drop_prob: 0.1 + noise_schedules: cosine + pred_objectives: noise + lowres_noise_schedule: linear + lowres_sample_noise_level: 0.2 + per_sample_random_aug_noise_level: False + condition_on_text: True + auto_normalize_img: True + p2_loss_weight_gamma: 0.5 + dynamic_thresholding: True, + dynamic_thresholding_percentile: 0.95 + only_train_unet_number: 1 + use_recompute: False + +Data: + Train: + dataset: + name: ImagenDataset + input_path: 
./data/cc12m_base64.lst + shuffle: True + input_resolusion: 256 + max_seq_len: 128 + loader: + num_workers: 8 + shuffle: True + batch_size: 1 + drop_last: True + collate_fn: imagen_collate_fn + + +Loss: + name: mse_loss + p2_loss_weight_k: 1.0 + +Distributed: + dp_degree: 1 + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 + sharding_offload: False diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml new file mode 100644 index 000000000..c1c19e84c --- /dev/null +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml @@ -0,0 +1,56 @@ +_base_: ./imagen_base.yaml + +Global: + global_batch_size: + local_batch_size: 1 + micro_batch_size: 1 + + +Model: + name: imagen_SR512 + text_encoder_name: t5/t5-11b + text_embed_dim: 1024 + timesteps: 1000 + in_chans: 3 + cond_drop_prob: 0.1 + noise_schedules: cosine + pred_objectives: noise + lowres_noise_schedule: linear + lowres_sample_noise_level: 0.2 + per_sample_random_aug_noise_level: False + condition_on_text: True + auto_normalize_img: True + p2_loss_weight_gamma: 0.5 + dynamic_thresholding: True, + dynamic_thresholding_percentile: 0.95 + only_train_unet_number: 1 + use_recompute: False + +Data: + Train: + dataset: + name: ImagenDataset + input_path: ./data/cc12m_base64.lst + shuffle: True + input_resolusion: 512 + max_seq_len: 128 + loader: + num_workers: 8 + shuffle: True + batch_size: 1 + drop_last: True + collate_fn: imagen_collate_fn + + +Loss: + name: mse_loss + p2_loss_weight_k: 1.0 + +Distributed: + dp_degree: 1 + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 + sharding_offload: False diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index 671c2ea8b..0a461d411 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -259,7 +259,7 @@ def _train_one_epoch(self, # Note(GuoxiaWang): Do not use len(train_data_loader()), # it will cause a memory leak. total_train_batch = len(train_data_loader) - total_eval_batch = len(valid_data_loader) + total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0 for step, batch in enumerate(train_data_loader): if epoch_index == self._load_recovery['epoch']: diff --git a/ppfleetx/models/multimodal_model/imagen/__init__.py b/ppfleetx/models/multimodal_model/imagen/__init__.py index 8bafdde60..d0cb2f4f0 100644 --- a/ppfleetx/models/multimodal_model/imagen/__init__.py +++ b/ppfleetx/models/multimodal_model/imagen/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .modeling import (ImagenModel, imagen_397M_text2im_64, +from .modeling import (ImagenModel, imagen_397M_text2im_64, imagen_2B_text2im_64, imagen_text2im_64_SR256, - imagen_SR256, imagen_SR1024, ImagenCriterion) + imagen_SR256, imagen_SR1024, imagen_SR512, ImagenCriterion) diff --git a/ppfleetx/models/multimodal_model/imagen/modeling.py b/ppfleetx/models/multimodal_model/imagen/modeling.py index bb0258274..4a8823225 100644 --- a/ppfleetx/models/multimodal_model/imagen/modeling.py +++ b/ppfleetx/models/multimodal_model/imagen/modeling.py @@ -814,6 +814,14 @@ def imagen_SR256(**kwargs): return model +def imagen_SR512(**kwargs): + model = ImagenModel(unets=SRUnet1024(), image_sizes=(512, ), **kwargs) + return model + def imagen_SR1024(**kwargs): model = ImagenModel(unets=SRUnet1024(), image_sizes=(1024, ), **kwargs) return model + +def imagen_SR64to1024(**kwargs): + model = ImagenModel(unets=SRUnet64to1024(), image_sizes=(1024, ), **kwargs) + return model diff --git a/ppfleetx/models/multimodal_model/multimodal_module.py b/ppfleetx/models/multimodal_model/multimodal_module.py index 43064f69f..8eff8a20e 100644 --- a/ppfleetx/models/multimodal_model/multimodal_module.py +++ b/ppfleetx/models/multimodal_model/multimodal_module.py @@ -27,9 +27,10 @@ class MultiModalModule(BasicModule): def __init__(self, configs): self.nranks = paddle.distributed.get_world_size() - super(MultiModalModule, self).__init__(configs) + self.loss_fn = self.get_loss_fn() + def process_configs(self, configs): configs = process_configs(configs) return configs diff --git a/projects/imagen/run_super_resolusion_1024_single.sh b/projects/imagen/run_super_resolusion_1024_single.sh new file mode 100644 index 000000000..e5b9d63f6 --- /dev/null +++ b/projects/imagen/run_super_resolusion_1024_single.sh @@ -0,0 +1,18 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export CUDA_VISIBLE_DEVICES=0 +python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml -o Data.Train.loader.num_workers=0 diff --git a/projects/imagen/run_super_resolusion_512_single.sh b/projects/imagen/run_super_resolusion_512_single.sh new file mode 100644 index 000000000..4a74ae642 --- /dev/null +++ b/projects/imagen/run_super_resolusion_512_single.sh @@ -0,0 +1,18 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +export CUDA_VISIBLE_DEVICES=0 +python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml -o Data.Train.loader.num_workers=8 \ No newline at end of file diff --git a/projects/imagen/run_text2im_397M_64x64_single.sh b/projects/imagen/run_text2im_397M_64x64_single.sh index 79ba52bfb..859688c9c 100644 --- a/projects/imagen/run_text2im_397M_64x64_single.sh +++ b/projects/imagen/run_text2im_397M_64x64_single.sh @@ -15,4 +15,4 @@ # limitations under the License. export CUDA_VISIBLE_DEVICES=0 -python tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_397M_text2im_64x64.yaml -o Data.Train.loader.num_workers=8 +python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_397M_text2im_64x64.yaml -o Data.Train.loader.num_workers=8 From 9ca662c1546c302a6b4bb29cd4bb51fc6330a708 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Mon, 19 Sep 2022 07:28:45 +0000 Subject: [PATCH 02/10] [WIP] Add recompute support for imagen model. --- .../imagen/imagen_super_resolusion_1024.yaml | 5 +- .../multimodal_model/imagen/modeling.py | 7 +- .../models/multimodal_model/imagen/unet.py | 70 ++++++++++++++++--- 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index 658d03ad4..6141960c0 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -24,7 +24,8 @@ Model: dynamic_thresholding: True, dynamic_thresholding_percentile: 0.95 only_train_unet_number: 1 - use_recompute: False + use_recompute: True + recompute_granularity: full Data: Train: @@ -37,7 +38,7 @@ Data: loader: num_workers: 8 shuffle: True - batch_size: 1 + batch_size: 1 drop_last: True collate_fn: imagen_collate_fn diff --git a/ppfleetx/models/multimodal_model/imagen/modeling.py b/ppfleetx/models/multimodal_model/imagen/modeling.py index 4a8823225..939c59e11 100644 --- a/ppfleetx/models/multimodal_model/imagen/modeling.py +++ b/ppfleetx/models/multimodal_model/imagen/modeling.py @@ -153,6 +153,7 @@ def __init__(self, dynamic_thresholding_percentile=0.95, only_train_unet_number=None, use_recompute=False, + recompute_granularity="full", fused_linear=False): super().__init__() @@ -165,6 +166,9 @@ def __init__(self, self.channels = in_chans + # use recompute + self.use_recompute = use_recompute + # automatically take care of ensuring that first unet is unconditional # while the rest of the unets are conditioned on the low resolution image produced by previous unet @@ -691,7 +695,8 @@ def p_losses(self, lowres_noise_times=self.lowres_noise_schedule.get_condition( lowres_aug_times), lowres_cond_img=lowres_cond_img_noisy, - cond_drop_prob=self.cond_drop_prob, ) + cond_drop_prob=self.cond_drop_prob, + use_recompute=self.use_recompute) # prediction objective diff --git a/ppfleetx/models/multimodal_model/imagen/unet.py b/ppfleetx/models/multimodal_model/imagen/unet.py index 923f887d6..8d877dbe7 100644 --- a/ppfleetx/models/multimodal_model/imagen/unet.py +++ b/ppfleetx/models/multimodal_model/imagen/unet.py @@ -20,6 +20,7 @@ from paddle import nn from paddle import nn, einsum import paddle.nn.functional as F +from paddle.distributed.fleet.utils import recompute from .utils import (zeros_, zero_init_, default, exists, cast_tuple, resize_image_to, prob_mask_like, masked_mean, Identity, @@ -873,6 +874,8 @@ def __init__(self, # save locals to take care of some 
hyperparameters for cascading DDPM + self.count = 0 + self._locals = locals() self._locals.pop('self', None) self._locals.pop('__class__', None) @@ -1298,7 +1301,8 @@ def forward(self, text_embeds=None, text_mask=None, cond_images=None, - cond_drop_prob=0.): + cond_drop_prob=0., + use_recompute=False): batch_size = x.shape[0] # add low resolution conditioning, if present @@ -1340,6 +1344,8 @@ def forward(self, time_tokens = self.to_time_tokens(time_hiddens) t = self.to_time_cond(time_hiddens) + if use_recompute: + t.stop_gradient = True # add lowres time conditioning to time hiddens # and add lowres time tokens along sequence dimension for attention @@ -1426,6 +1432,8 @@ def forward(self, # normalize conditioning tokens c = self.norm_cond(c) + if use_recompute: + c.stop_gradient = True if exists(self.init_resnet_block): x = self.init_resnet_block(x, t) @@ -1434,19 +1442,35 @@ def forward(self, for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: if exists(pre_downsample): - x = pre_downsample(x) + if use_recompute: + x = recompute(pre_downsample, x) + else: + x = pre_downsample(x) + - x = init_block(x, t, c) + if use_recompute: + x = init_block(x, t, c) + else: + x = recompute(init_block, x, t, c) for resnet_block in resnet_blocks: - x = resnet_block(x, t) + if use_recompute: + x = recompute(resnet_block, x, t) + else: + x = resnet_block(x, t) hiddens.append(x) - x = attn_block(x, c) + if use_recompute: + x = recompute(attn_block, x, c) + else: + x = attn_block(x, c) hiddens.append(x) if exists(post_downsample): - x = post_downsample(x) + if use_recompute: + x = recompute(post_downsample, x) + else: + x = post_downsample(x) x = self.mid_block1(x, t, c) @@ -1461,15 +1485,27 @@ def forward(self, for init_block, resnet_blocks, attn_block, upsample in self.ups: x = add_skip_connection(x) - x = init_block(x, t, c) + if use_recompute: + x = recompute(init_block, x, t, c) + else: + x = init_block(x, t, c) for resnet_block in resnet_blocks: x = add_skip_connection(x) - x = resnet_block(x, t) - - x = attn_block(x, c) + if use_recompute: + x = recompute(resnet_block, x, t) + else: + x = resnet_block(x, t) + + if use_recompute: + x = recompute(attn_block, x, c) + else: + x = attn_block(x, c) up_hiddens.append(x) - x = upsample(x) + if use_recompute: + x = recompute(upsample, x) + else: + x = upsample(x) x = self.upsample_combiner(x, up_hiddens) @@ -1481,5 +1517,17 @@ def forward(self, if exists(lowres_cond_img): x = paddle.concat((x, lowres_cond_img), axis=1) + + # output = self.final_conv(x) + + # import numpy as np + # np.save("origin_data/output.npy", output.numpy()) + + # if self.count == 1: + # for name + + # self.count += 1 + + # return output return self.final_conv(x) From 15b8c38c5cd37bb5e12ac1aae1e99130d817dc84 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Tue, 20 Sep 2022 05:41:21 +0000 Subject: [PATCH 03/10] Add gradient-merge support. 
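Gradient merge as added below is plain gradient accumulation: several forward/backward passes share one optimizer step, and the data-parallel allreduce is suppressed until the step that actually updates (the `update_parameters` flag in `_fit_impl`). A minimal, self-contained sketch of that pattern, assuming Paddle's `DataParallel.no_sync()` and a toy model rather than the real Imagen U-Net:

```python
# Standalone sketch of the gradient-merge pattern wired into _fit_impl below.
# Names (toy model, accumulate_steps, random data) are illustrative only.
# Runs single-process as-is; use `python -m paddle.distributed.launch` with
# several devices to get real data parallelism.
import paddle
import paddle.nn as nn

accumulate_steps = 4                      # assumed Engine.accumulate_steps
model = nn.Linear(16, 16)

if paddle.distributed.get_world_size() > 1:
    paddle.distributed.init_parallel_env()
    model = paddle.DataParallel(model)

opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

for step in range(32):
    x = paddle.randn([8, 16])
    sync_now = (step + 1) % accumulate_steps == 0
    if isinstance(model, paddle.DataParallel) and not sync_now:
        with model.no_sync():             # keep gradients local this step
            model(x).mean().backward()
    else:
        model(x).mean().backward()        # allreduce happens on this backward
    if sync_now:
        opt.step()                        # one update per accumulate_steps
        opt.clear_grad()
```

With recompute enabled, the engine instead always runs backward under `no_sync()` and triggers the allreduce explicitly (`fused_allreduce_gradients` or the fused-tensor path), which is the other branch visible in the diff below.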
--- ppfleetx/core/engine/eager_engine.py | 32 ++++++++++++------- .../multimodal_model/imagen/modeling.py | 10 +++--- .../models/multimodal_model/imagen/unet.py | 17 ++-------- .../imagen/run_super_resolusion_1024_DP8.sh | 23 +++++++++++++ 4 files changed, 51 insertions(+), 31 deletions(-) create mode 100644 projects/imagen/run_super_resolusion_1024_DP8.sh diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index ece480e58..931394779 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -123,6 +123,7 @@ def configure_optimizers(self): self._test_iters = self._configs['test_iters'] self._logging_freq = self._configs['logging_freq'] self._num_train_epochs = self._configs['num_train_epochs'] + print("__init__ accumulation steps: ", self._configs['accumulate_steps']) self._accumulate_steps = self._configs['accumulate_steps'] self._use_pure_fp16 = self._configs['mix_precision']['use_pure_fp16'] @@ -267,7 +268,7 @@ def _train_one_epoch(self, if step < self._load_recovery['step']: continue - loss = self._fit_impl(batch) + loss = self._fit_impl(batch, step) # Sync for profile time, delete it may be a little faster paddle.device.cuda.synchronize() train_costs = time.time() - train_start @@ -380,23 +381,32 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None): if self.profiler: self._profiler_done() - def _fit_impl(self, batch): + def _fit_impl(self, batch, step): batch = self._module.pretreating_batch(batch) if self._pp_degree == 1: + print("accumulate steps: ", self._accumulate_steps) + update_parameters = (step != 0 and step % self._accumulate_steps == 0) if self._use_recompute and isinstance(self._module.model, paddle.DataParallel): with self._module.model.no_sync(): loss = self._model_forward_backward(batch) - if not hasattr(self._optimizer, "all_fused_tensors" - ) or self._optimizer.all_fused_tensors is None: - fused_allreduce_gradients( - list(self._module.model.parameters()), None) - else: - all_reduce_parameters(self._optimizer.all_fused_tensors, - self._dp_group) + if update_parameters: + if not hasattr(self._optimizer, "all_fused_tensors" + ) or self._optimizer.all_fused_tensors is None: + fused_allreduce_gradients( + list(self._module.model.parameters()), None) + else: + all_reduce_parameters(self._optimizer.all_fused_tensors, + self._dp_group) else: - loss = self._model_forward_backward(batch) - self._optim_update_params() + if update_parameters: + loss = self._model_forward_backward(batch) + else: + with self._module.model.no_sync(): + loss = self._model_forward_backward(batch) + if update_parameters: + print("current step: {}, update parameters".format(step), "****" * 40) + self._optim_update_params() else: with paddle.amp.auto_cast( self._use_pure_fp16, diff --git a/ppfleetx/models/multimodal_model/imagen/modeling.py b/ppfleetx/models/multimodal_model/imagen/modeling.py index 939c59e11..e85a79bee 100644 --- a/ppfleetx/models/multimodal_model/imagen/modeling.py +++ b/ppfleetx/models/multimodal_model/imagen/modeling.py @@ -288,15 +288,15 @@ def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 - if isinstance(self.unets, nn.LayerList): - unets_list = [unet for unet in self.unets] - delattr(self, 'unets') - self.unets = unets_list + # if isinstance(self.unets, nn.LayerList): + # unets_list = [unet for unet in self.unets] + # delattr(self, 'unets') + # self.unets = unets_list self.unet_being_trained_index = index return self.unets[index] def 
reset_unets(self, ): - self.unets = nn.LayerList([*self.unets]) + # self.unets = nn.LayerList([*self.unets]) self.unet_being_trained_index = -1 @contextmanager diff --git a/ppfleetx/models/multimodal_model/imagen/unet.py b/ppfleetx/models/multimodal_model/imagen/unet.py index 8d877dbe7..a1f111849 100644 --- a/ppfleetx/models/multimodal_model/imagen/unet.py +++ b/ppfleetx/models/multimodal_model/imagen/unet.py @@ -874,7 +874,6 @@ def __init__(self, # save locals to take care of some hyperparameters for cascading DDPM - self.count = 0 self._locals = locals() self._locals.pop('self', None) @@ -1449,9 +1448,9 @@ def forward(self, if use_recompute: - x = init_block(x, t, c) + x = recompute(init_block, x, t, c) else: - x = recompute(init_block, x, t, c) + x = init_block(x, t, c) for resnet_block in resnet_blocks: if use_recompute: @@ -1518,16 +1517,4 @@ def forward(self, if exists(lowres_cond_img): x = paddle.concat((x, lowres_cond_img), axis=1) - # output = self.final_conv(x) - - # import numpy as np - # np.save("origin_data/output.npy", output.numpy()) - - # if self.count == 1: - # for name - - # self.count += 1 - - - # return output return self.final_conv(x) diff --git a/projects/imagen/run_super_resolusion_1024_DP8.sh b/projects/imagen/run_super_resolusion_1024_DP8.sh new file mode 100644 index 000000000..fd1750703 --- /dev/null +++ b/projects/imagen/run_super_resolusion_1024_DP8.sh @@ -0,0 +1,23 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log_dir=log_imagen_1024_DP8 +rm -rf $log_dir + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3" \ + tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml \ + -o Data.Train.loader.num_workers=8 From 4ef3ded70ac7a705f79ea74222c496c80cfbf608 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Tue, 20 Sep 2022 11:18:12 +0000 Subject: [PATCH 04/10] Fix some problems. 
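Among the problems fixed below is that the data-parallel launch script and the YAML disagreed: the config still declared `dp_degree: 1` with `local_batch_size: 1` while the script launched multiple workers. The invariant the config processing asserts, and the way the batch sizes follow from it, fit in a few lines; the concrete numbers here are illustrative, not taken from the shipped config (which leaves `global_batch_size` empty):

```python
# Hypothetical helpers mirroring the consistency checks around dp_degree and
# batch sizes; the numeric values are examples only.
def check_dist_config(nranks, dp, mp=1, pp=1, sharding=1):
    # every launched rank must belong to exactly one parallel group
    assert nranks == dp * mp * pp * sharding, (
        f"{nranks} ranks cannot be split as dp={dp}, mp={mp}, "
        f"pp={pp}, sharding={sharding}")

def derive_batch_sizes(global_bs, dp, accumulate_steps=1):
    local_bs = global_bs // dp               # per data-parallel rank
    micro_bs = local_bs // accumulate_steps  # per forward/backward pass
    return local_bs, micro_bs

check_dist_config(nranks=8, dp=8)  # 8 devices, as in run_super_resolusion_1024_DP8.sh below
print(derive_batch_sizes(global_bs=128, dp=8, accumulate_steps=16))  # -> (16, 1)
```

The derived pair (16, 1) matches the `local_batch_size: 16` / `micro_batch_size: 1` combination the config is updated to below.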
--- .../multimodal/imagen/imagen_super_resolusion_1024.yaml | 4 ++-- ppfleetx/core/engine/eager_engine.py | 3 --- ppfleetx/models/multimodal_model/multimodal_module.py | 4 ++-- projects/imagen/run_super_resolusion_1024_DP8.sh | 4 ++-- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index cf99d7293..6a4168fda 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -2,7 +2,7 @@ _base_: ./imagen_base.yaml Global: global_batch_size: - local_batch_size: 1 + local_batch_size: 16 micro_batch_size: 1 @@ -48,7 +48,7 @@ Loss: p2_loss_weight_k: 1.0 Distributed: - dp_degree: 1 + dp_degree: 8 mp_degree: 1 pp_degree: 1 sharding: diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index 931394779..7dcb69413 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -123,7 +123,6 @@ def configure_optimizers(self): self._test_iters = self._configs['test_iters'] self._logging_freq = self._configs['logging_freq'] self._num_train_epochs = self._configs['num_train_epochs'] - print("__init__ accumulation steps: ", self._configs['accumulate_steps']) self._accumulate_steps = self._configs['accumulate_steps'] self._use_pure_fp16 = self._configs['mix_precision']['use_pure_fp16'] @@ -384,7 +383,6 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None): def _fit_impl(self, batch, step): batch = self._module.pretreating_batch(batch) if self._pp_degree == 1: - print("accumulate steps: ", self._accumulate_steps) update_parameters = (step != 0 and step % self._accumulate_steps == 0) if self._use_recompute and isinstance(self._module.model, paddle.DataParallel): @@ -405,7 +403,6 @@ def _fit_impl(self, batch, step): with self._module.model.no_sync(): loss = self._model_forward_backward(batch) if update_parameters: - print("current step: {}, update parameters".format(step), "****" * 40) self._optim_update_params() else: with paddle.amp.auto_cast( diff --git a/ppfleetx/models/multimodal_model/multimodal_module.py b/ppfleetx/models/multimodal_model/multimodal_module.py index 8eff8a20e..2024518d7 100644 --- a/ppfleetx/models/multimodal_model/multimodal_module.py +++ b/ppfleetx/models/multimodal_model/multimodal_module.py @@ -47,7 +47,7 @@ def training_step(self, batch): return loss def training_step_end(self, log_dict): - speed = self.configs.Engine.logging_freq / log_dict['train_cost'] + speed = 1.0 / log_dict['train_cost'] logger.info( "[train] epoch: %d, batch: %d, loss: %.9f, avg_batch_cost: %.5f sec, speed: %.2f step/s, learning rate: %.5e" @@ -62,7 +62,7 @@ def validation_step(self, batch): return loss def validation_step_end(self, log_dict): - speed = self.configs.Engine.logging_freq / log_dict['eval_cost'] + speed = 1.0 / log_dict['eval_cost'] logger.info( "[eval] epoch: %d, batch: %d, loss: %.9f, avg_eval_cost: %.5f sec, speed: %.2f step/s" % (log_dict['epoch'], log_dict['batch'], log_dict['loss'], diff --git a/projects/imagen/run_super_resolusion_1024_DP8.sh b/projects/imagen/run_super_resolusion_1024_DP8.sh index fd1750703..b4fd8bbc7 100644 --- a/projects/imagen/run_super_resolusion_1024_DP8.sh +++ b/projects/imagen/run_super_resolusion_1024_DP8.sh @@ -17,7 +17,7 @@ log_dir=log_imagen_1024_DP8 rm -rf $log_dir -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m 
paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3" \ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \ tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml \ -o Data.Train.loader.num_workers=8 From 6476c2940480588ba080f7478ac8a4c58b2f903c Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Wed, 21 Sep 2022 12:53:50 +0000 Subject: [PATCH 05/10] Adapting imagen model for bfloat16 dtype. --- .../imagen/imagen_super_resolusion_1024.yaml | 7 ++++ ppfleetx/core/engine/eager_engine.py | 13 ++++--- .../models/multimodal_model/imagen/unet.py | 36 +++++++------------ .../multimodal_model/multimodal_module.py | 2 +- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index 6a4168fda..c573f1f85 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -27,6 +27,13 @@ Model: use_recompute: True recompute_granularity: full +Engine: + mix_precision: + use_pure_fp16: True + scale_loss: 32768.0 + custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"] + custom_white_list: ["lookup_table", "lookup_table_v2"] + Data: Train: dataset: diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index 7dcb69413..481b22118 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -162,7 +162,7 @@ def configure_optimizers(self): # Save dtype is the same as model dtype. Also can set save_dtype='float32' when # training with pure fp16 strategy, but will cause the rise of memory. 
self._module.model = paddle.amp.decorate( - models=self._module.model, level='O2') + models=self._module.model, level='O2', dtype='bfloat16') else: self._scaler = None @@ -382,6 +382,8 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None): def _fit_impl(self, batch, step): batch = self._module.pretreating_batch(batch) + with paddle.no_grad(): + batch = [paddle.cast(t, dtype=paddle.bfloat16) if t.dtype == paddle.float32 else t for t in batch] if self._pp_degree == 1: update_parameters = (step != 0 and step % self._accumulate_steps == 0) if self._use_recompute and isinstance(self._module.model, @@ -397,7 +399,8 @@ def _fit_impl(self, batch, step): all_reduce_parameters(self._optimizer.all_fused_tensors, self._dp_group) else: - if update_parameters: + if update_parameters or not isinstance(self._module.model, + paddle.DataParallel): loss = self._model_forward_backward(batch) else: with self._module.model.no_sync(): @@ -409,7 +412,8 @@ def _fit_impl(self, batch, step): self._use_pure_fp16, custom_black_list=self._custom_black_list, custom_white_list=self._custom_white_list, - level='O2'): + level='O2', + dtype='bfloat16'): loss = self._module.model.train_batch( batch, optimizer=self._optimizer, @@ -422,7 +426,8 @@ def _model_forward_backward(self, batch): self._use_pure_fp16, custom_black_list=self._custom_black_list, custom_white_list=self._custom_white_list, - level='O2'): + level='O2', + dtype='bfloat16'): loss = self._module.training_step(batch) loss_bw = self._scaler.scale(loss) if self._use_pure_fp16 else loss diff --git a/ppfleetx/models/multimodal_model/imagen/unet.py b/ppfleetx/models/multimodal_model/imagen/unet.py index a1f111849..aebc6bad5 100644 --- a/ppfleetx/models/multimodal_model/imagen/unet.py +++ b/ppfleetx/models/multimodal_model/imagen/unet.py @@ -1441,11 +1441,7 @@ def forward(self, for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: if exists(pre_downsample): - if use_recompute: - x = recompute(pre_downsample, x) - else: - x = pre_downsample(x) - + x = pre_downsample(x) if use_recompute: x = recompute(init_block, x, t, c) @@ -1453,10 +1449,7 @@ def forward(self, x = init_block(x, t, c) for resnet_block in resnet_blocks: - if use_recompute: - x = recompute(resnet_block, x, t) - else: - x = resnet_block(x, t) + x = resnet_block(x, t) hiddens.append(x) if use_recompute: @@ -1466,17 +1459,20 @@ def forward(self, hiddens.append(x) if exists(post_downsample): - if use_recompute: - x = recompute(post_downsample, x) - else: - x = post_downsample(x) + x = post_downsample(x) - x = self.mid_block1(x, t, c) + if use_recompute: + x = recompute(self.mid_block1, x, t, c) + else: + x = self.mid_block1(x, t, c) if exists(self.mid_attn): x = self.mid_attn(x) - x = self.mid_block2(x, t, c) + if use_recompute: + x = recompute(self.mid_block2, x, t, c) + else: + x = self.mid_block2(x, t, c) add_skip_connection = lambda x: paddle.concat((x, hiddens.pop() * self.skip_connect_scale), axis=1) @@ -1491,20 +1487,14 @@ def forward(self, for resnet_block in resnet_blocks: x = add_skip_connection(x) - if use_recompute: - x = recompute(resnet_block, x, t) - else: - x = resnet_block(x, t) + x = resnet_block(x, t) if use_recompute: x = recompute(attn_block, x, c) else: x = attn_block(x, c) up_hiddens.append(x) - if use_recompute: - x = recompute(upsample, x) - else: - x = upsample(x) + x = upsample(x) x = self.upsample_combiner(x, up_hiddens) diff --git a/ppfleetx/models/multimodal_model/multimodal_module.py 
b/ppfleetx/models/multimodal_model/multimodal_module.py index 2024518d7..a0ace54f2 100644 --- a/ppfleetx/models/multimodal_model/multimodal_module.py +++ b/ppfleetx/models/multimodal_model/multimodal_module.py @@ -76,7 +76,7 @@ def test_step(self, batch): return loss def test_step_end(self, log_dict): - speed = self.configs.Engine.logging_freq / log_dict['test_cost'] + speed = 1.0 / log_dict['test_cost'] logger.info( "[test] epoch: %d, batch: %d, loss: %.9f, avg_test_cost: %.5f sec, speed: %.2f step/s" % (log_dict['epoch'], log_dict['batch'], log_dict['loss'], From c5a8795022de1aee04e068931574f2c988c0c9a8 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Wed, 28 Sep 2022 02:58:33 +0000 Subject: [PATCH 06/10] [WIP] add sharding and bfloat16 training strategy. --- .../imagen/imagen_super_resolusion_1024.yaml | 6 ++--- .../run_super_resolusion_1024_sharding.sh | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 projects/imagen/run_super_resolusion_1024_sharding.sh diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index c573f1f85..48a4aab57 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -2,7 +2,7 @@ _base_: ./imagen_base.yaml Global: global_batch_size: - local_batch_size: 16 + local_batch_size: 1 micro_batch_size: 1 @@ -24,7 +24,7 @@ Model: dynamic_thresholding: True, dynamic_thresholding_percentile: 0.95 only_train_unet_number: 1 - use_recompute: True + use_recompute: False recompute_granularity: full Engine: @@ -55,7 +55,7 @@ Loss: p2_loss_weight_k: 1.0 Distributed: - dp_degree: 8 + dp_degree: 1 mp_degree: 1 pp_degree: 1 sharding: diff --git a/projects/imagen/run_super_resolusion_1024_sharding.sh b/projects/imagen/run_super_resolusion_1024_sharding.sh new file mode 100644 index 000000000..2584b5c10 --- /dev/null +++ b/projects/imagen/run_super_resolusion_1024_sharding.sh @@ -0,0 +1,24 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log_dir=log_imagen_1024_sharding +rm -rf $log_dir + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \ + tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml \ + -o Data.Train.loader.num_workers=8 -o Distributed.sharding.sharding_degree=8 \ + -o Distributed.sharding.sharding_stage=2 From f79ef92a51838d6703e049ecd1ba1b65e79673b3 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Tue, 25 Oct 2022 02:25:48 +0000 Subject: [PATCH 07/10] Polish Code. 
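Before the polish below: the `run_super_resolusion_1024_sharding.sh` script added above overrides the config to `sharding_degree=8`, `sharding_stage=2`, i.e. ZeRO-style partitioning of optimizer state and gradients across the eight ranks. In ppfleetx the wrapping happens inside `EagerEngine._wrap_sharding_2_3`; a bare-bones equivalent using Paddle's public `group_sharded_parallel` API (recent Paddle releases; toy model; assumes a multi-GPU `paddle.distributed.launch` job) would look roughly like this:

```python
# Rough sketch of stage-2 sharding with Paddle's group_sharded_parallel.
# This is not the ppfleetx wrapper itself; model and sizes are placeholders.
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

# level="os_g" shards Optimizer State + Gradients (stage 2);
# "os" would be stage 1 and "p_g_os" stage 3.
model, opt, scaler = group_sharded_parallel(model=model, optimizer=opt,
                                            level="os_g")

for _ in range(4):
    loss = model(paddle.randn([2, 1024])).mean()
    loss.backward()      # gradients are reduced to their owning rank
    opt.step()           # each rank updates only its shard of the state
    opt.clear_grad()
```

Stage 2 keeps parameters replicated while sharding the memory-heavy optimizer state and gradients, which is why it combines naturally with the bfloat16 settings introduced in the previous patch.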
--- .../imagen/imagen_super_resolusion_1024.yaml | 1 + ppfleetx/core/engine/eager_engine.py | 77 ++++++++++--------- .../models/multimodal_model/imagen/unet.py | 10 +-- 3 files changed, 48 insertions(+), 40 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index 48a4aab57..2a1b5ebe0 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -33,6 +33,7 @@ Engine: scale_loss: 32768.0 custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"] custom_white_list: ["lookup_table", "lookup_table_v2"] + fp16_dtype: "bfloat16" Data: Train: diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index 2ff30b683..d0b6dfa76 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -136,6 +136,8 @@ def configure_optimizers(self): 'custom_black_list'] self._custom_white_list = self._configs['mix_precision'][ 'custom_white_list'] + self._fp16_dtype = "float16" if 'fp16_dtype' in self._configs['mix_precision'] \ + else self._configs['mix_precision']['fp16_dtype'] self._save_steps = self._configs['save_load']['save_steps'] self._save_epoch = self._configs['save_load']['save_epoch'] @@ -149,7 +151,7 @@ def configure_optimizers(self): self._mp_degree = self._dist_configs['mp_degree'] self._pp_degree = self._dist_configs['pp_degree'] sharding_config = self._dist_configs['sharding'] - + self._sharding_stage = sharding_config['sharding_stage'] self._sharding_degree = sharding_config['sharding_degree'] self._sharding_offload = sharding_config['sharding_offload'] @@ -252,9 +254,8 @@ def _wrap_sharding_2_3(self): if self._reduce_overlap: self._module.model._set_reduce_overlap(self._reduce_overlap) if self._broadcast_overlap: - self._optimizer._set_broadcast_overlap(self._broadcast_overlap, - layers=origin_model, - num_groups=2) + self._optimizer._set_broadcast_overlap( + self._broadcast_overlap, layers=origin_model, num_groups=2) def _wrap_3D_parallel(self): self._module.model = fleet.distributed_model(self._module.model) @@ -283,10 +284,7 @@ def _train_one_epoch(self, continue loss = self._fit_impl(batch, step) - # Sync for profile time, delete it may be a little faster - paddle.device.cuda.synchronize() - train_costs = time.time() - train_start - train_losses.append(loss.numpy()[0]) + train_losses.append(loss) if (step + 1) % self._logging_freq == 0: # Sync for profile time, delete it may be a little faster @@ -405,38 +403,35 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None): def _fit_impl(self, batch, step): batch = self._module.pretreating_batch(batch) - with paddle.no_grad(): - batch = [paddle.cast(t, dtype=paddle.bfloat16) if t.dtype == paddle.float32 else t for t in batch] + if self._fp16_dtype is 'bfloat16': + with paddle.no_grad(): + batch = [ + paddle.cast( + t, dtype=paddle.bfloat16) + if t.dtype == paddle.float32 else t for t in batch + ] if self._pp_degree == 1: - update_parameters = (step != 0 and step % self._accumulate_steps == 0) if self._use_recompute and isinstance(self._module.model, paddle.DataParallel): with self._module.model.no_sync(): loss = self._model_forward_backward(batch) - if update_parameters: - if not hasattr(self._optimizer, "all_fused_tensors" - ) or self._optimizer.all_fused_tensors is None: - fused_allreduce_gradients( - 
list(self._module.model.parameters()), None) - else: - all_reduce_parameters(self._optimizer.all_fused_tensors, - self._dp_group) - else: - if update_parameters or not isinstance(self._module.model, - paddle.DataParallel): - loss = self._model_forward_backward(batch) + if not hasattr(self._optimizer, "all_fused_tensors" + ) or self._optimizer.all_fused_tensors is None: + fused_allreduce_gradients( + list(self._module.model.parameters()), None) else: - with self._module.model.no_sync(): - loss = self._model_forward_backward(batch) - if update_parameters: - self._optim_update_params() + all_reduce_parameters(self._optimizer.all_fused_tensors, + self._dp_group) + else: + loss = self._model_forward_backward(batch) + self._optim_update_params() else: with paddle.amp.auto_cast( self._use_pure_fp16, custom_black_list=self._custom_black_list, custom_white_list=self._custom_white_list, level='O2', - dtype='bfloat16'): + dtype=self._fp16_dtype): loss = self._module.model.train_batch( batch, optimizer=self._optimizer, @@ -445,13 +440,25 @@ def _fit_impl(self, batch, step): return loss def _model_forward_backward(self, batch): - with paddle.amp.auto_cast( - self._use_pure_fp16, - custom_black_list=self._custom_black_list, - custom_white_list=self._custom_white_list, - level='O2', - dtype='bfloat16'): - loss = self._module.training_step(batch) + if self._accumulate_steps == 1 or self._pp_degree > 1: + batches = [batch] + else: + split_batches = [ + paddle.split(b, self._accumulate_steps) for b in batch + ] + batches = [] + for i in range(len(split_batches[0])): + micro_batch = [split_batch[i] for split_batch in split_batches] + batches.append(micro_batch) + final_loss = None + for micro_batch in batches: + with paddle.amp.auto_cast( + self._use_pure_fp16, + custom_black_list=self._custom_black_list, + custom_white_list=self._custom_white_list, + level='O2', + dtype=self._fp16_dtype): + loss = self._module.training_step(micro_batch) loss_bw = self._scaler.scale(loss) if self._use_pure_fp16 else loss self._module.backward(loss_bw) detach_loss = loss.detach() diff --git a/ppfleetx/models/multimodal_model/imagen/unet.py b/ppfleetx/models/multimodal_model/imagen/unet.py index aebc6bad5..c6a7e9f6e 100644 --- a/ppfleetx/models/multimodal_model/imagen/unet.py +++ b/ppfleetx/models/multimodal_model/imagen/unet.py @@ -874,7 +874,6 @@ def __init__(self, # save locals to take care of some hyperparameters for cascading DDPM - self._locals = locals() self._locals.pop('self', None) self._locals.pop('__class__', None) @@ -1022,8 +1021,9 @@ def __init__(self, layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) assert all([ - layers == num_layers for layers in - list(map(len, (resnet_groups, layer_attns, layer_cross_attns))) + layers == num_layers + for layers in list( + map(len, (resnet_groups, layer_attns, layer_cross_attns))) ]) # downsample klass @@ -1444,7 +1444,7 @@ def forward(self, x = pre_downsample(x) if use_recompute: - x = recompute(init_block, x, t, c) + x = recompute(init_block, x, t, c) else: x = init_block(x, t, c) @@ -1506,5 +1506,5 @@ def forward(self, if exists(lowres_cond_img): x = paddle.concat((x, lowres_cond_img), axis=1) - + return self.final_conv(x) From 30099e6ebe3604f8c6eab687e69dcf1f85d8c6a3 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Tue, 25 Oct 2022 02:36:21 +0000 Subject: [PATCH 08/10] Polish code. 
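The reworked `_model_forward_backward` above folds gradient accumulation into the forward/backward itself: the local batch is cut into `accumulate_steps` slices with `paddle.split`, and each slice gets its own forward and backward before the single optimizer step. The slicing in isolation (tensor shapes invented for illustration; the real batch list comes from `imagen_collate_fn`):

```python
# Illustrative micro-batch slicing, mirroring the paddle.split logic above.
import paddle

accumulate_steps = 4
# a "batch" here is a list of tensors (e.g. images and text embeddings)
batch = [paddle.randn([8, 3, 64, 64]), paddle.randn([8, 128, 1024])]

# split every tensor along dim 0 into accumulate_steps equal pieces ...
split_batches = [paddle.split(t, accumulate_steps) for t in batch]
# ... then regroup piece i of every tensor into micro-batch i
micro_batches = [[parts[i] for parts in split_batches]
                 for i in range(accumulate_steps)]

for micro_batch in micro_batches:
    images, text_embeds = micro_batch
    assert images.shape[0] == 2           # 8 samples / 4 micro steps
    # loss = module.training_step(micro_batch); loss.backward()   # per slice
```

Note that this requires the per-rank batch size to be divisible by `accumulate_steps`.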
--- ppfleetx/core/engine/eager_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index d0b6dfa76..bd4b32ab5 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -168,7 +168,7 @@ def configure_optimizers(self): # Save dtype is the same as model dtype. Also can set save_dtype='float32' when # training with pure fp16 strategy, but will cause the rise of memory. self._module.model = paddle.amp.decorate( - models=self._module.model, level='O2', dtype='bfloat16') + models=self._module.model, level='O2', dtype=self._fp16_dtype) else: self._scaler = None @@ -283,7 +283,7 @@ def _train_one_epoch(self, if step < self._load_recovery['step']: continue - loss = self._fit_impl(batch, step) + loss = self._fit_impl(batch) train_losses.append(loss) if (step + 1) % self._logging_freq == 0: @@ -401,7 +401,7 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None): if self.profiler: self._profiler_done() - def _fit_impl(self, batch, step): + def _fit_impl(self, batch): batch = self._module.pretreating_batch(batch) if self._fp16_dtype is 'bfloat16': with paddle.no_grad(): From 8947f4eb440d6ddeb1809b2b723c36b59486de8f Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Tue, 25 Oct 2022 03:25:02 +0000 Subject: [PATCH 09/10] Fix bug of config.py and eager_engine.py --- ppfleetx/core/engine/eager_engine.py | 2 +- ppfleetx/utils/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index bd4b32ab5..089e8ba5a 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -136,7 +136,7 @@ def configure_optimizers(self): 'custom_black_list'] self._custom_white_list = self._configs['mix_precision'][ 'custom_white_list'] - self._fp16_dtype = "float16" if 'fp16_dtype' in self._configs['mix_precision'] \ + self._fp16_dtype = "float16" if 'fp16_dtype' not in self._configs['mix_precision'] \ else self._configs['mix_precision']['fp16_dtype'] self._save_steps = self._configs['save_load']['save_steps'] diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py index 9e5998535..a8d46f866 100644 --- a/ppfleetx/utils/config.py +++ b/ppfleetx/utils/config.py @@ -55,7 +55,7 @@ def process_dist_config(configs): assert nranks == dp_degree * other_degree, \ "Mismatched config using {} cards with dp_degree[{}]," \ "mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(nranks, \ - dp_degree, mp_degree, pp_degree, _sharding_degree) + dp_degree, mp_degree, pp_degree, sharding_degree) if sharding_config['sharding_degree'] > 1 and reduce_overlap: if sharding_config['sharding_stage'] == 3 or sharding_config[ From 1caba29cd3fdd9696cd82ad2edb71156e3dc16be Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Thu, 27 Oct 2022 11:10:48 +0800 Subject: [PATCH 10/10] Polish Code. 
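The `not in` fix in patch 09 above is worth a second look: with the earlier expression, supplying `fp16_dtype: "bfloat16"` in the YAML silently fell back to "float16", and omitting the key raised a KeyError instead of giving the intended default. (A related wrinkle that survives the series is `if self._fp16_dtype is 'bfloat16':` in `_fit_impl`, which compares strings by identity; `==` is the robust spelling.) The corrected defaulting is equivalent to a plain dictionary lookup:

```python
# mix_precision stands in for self._configs['mix_precision']; values invented.
mix_precision = {"use_pure_fp16": True, "fp16_dtype": "bfloat16"}

# Earlier (buggy) form: returns "float16" exactly when the key IS present,
# and raises KeyError when it is not.
#   fp16_dtype = ("float16" if 'fp16_dtype' in mix_precision
#                 else mix_precision['fp16_dtype'])

# Fixed behaviour, as introduced in patch 09, written with dict.get():
fp16_dtype = mix_precision.get("fp16_dtype", "float16")
assert fp16_dtype == "bfloat16"
```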
--- .../imagen/imagen_super_resolusion_1024.yaml | 4 ++-- ppfleetx/core/engine/eager_engine.py | 16 ++++++++-------- ppfleetx/utils/config.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml index 2a1b5ebe0..b96376a45 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml @@ -60,6 +60,6 @@ Distributed: mp_degree: 1 pp_degree: 1 sharding: - sharding_degree: 1 - sharding_stage: 1 + sharding_degree: 8 + sharding_stage: 2 sharding_offload: False diff --git a/ppfleetx/core/engine/eager_engine.py b/ppfleetx/core/engine/eager_engine.py index bd4b32ab5..6a7a57e2d 100644 --- a/ppfleetx/core/engine/eager_engine.py +++ b/ppfleetx/core/engine/eager_engine.py @@ -136,7 +136,7 @@ def configure_optimizers(self): 'custom_black_list'] self._custom_white_list = self._configs['mix_precision'][ 'custom_white_list'] - self._fp16_dtype = "float16" if 'fp16_dtype' in self._configs['mix_precision'] \ + self._fp16_dtype = "float16" if 'fp16_dtype' not in self._configs['mix_precision'] \ else self._configs['mix_precision']['fp16_dtype'] self._save_steps = self._configs['save_load']['save_steps'] @@ -160,6 +160,13 @@ def configure_optimizers(self): self._use_recompute = configs['Model']['use_recompute'] + self._lr_scheduler = build_lr_scheduler( + configs.Optimizer.lr) if mode == 'train' else None + + self._optimizer = build_optimizer( + configs.Optimizer, self._module.model, + self._lr_scheduler) if mode == 'train' else None + if self._use_pure_fp16: if mode == 'train': self._scaler = paddle.amp.GradScaler( @@ -172,13 +179,6 @@ def configure_optimizers(self): else: self._scaler = None - self._lr_scheduler = build_lr_scheduler( - configs.Optimizer.lr) if mode == 'train' else None - - self._optimizer = build_optimizer( - configs.Optimizer, self._module.model, - self._lr_scheduler) if mode == 'train' else None - # distributed configs self._distributed = (dist.get_world_size() > 1) diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py index 9e5998535..a8d46f866 100644 --- a/ppfleetx/utils/config.py +++ b/ppfleetx/utils/config.py @@ -55,7 +55,7 @@ def process_dist_config(configs): assert nranks == dp_degree * other_degree, \ "Mismatched config using {} cards with dp_degree[{}]," \ "mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(nranks, \ - dp_degree, mp_degree, pp_degree, _sharding_degree) + dp_degree, mp_degree, pp_degree, sharding_degree) if sharding_config['sharding_degree'] > 1 and reduce_overlap: if sharding_config['sharding_stage'] == 3 or sharding_config[
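The final patch, besides defaulting the 1024 config to 8-way stage-2 sharding, moves optimizer construction ahead of the `paddle.amp.decorate` / `GradScaler` block, so the optimizer now captures the parameters before they are decorated for O2. Pulled out of the engine, the resulting pure-bf16 setup order looks roughly like the sketch below (toy model; assumes a bf16-capable GPU and a Paddle release where `decorate` and `auto_cast` accept `dtype='bfloat16'`; loss scaling is generally unnecessary for bfloat16's exponent range, so the 32768.0 value is mainly relevant to the float16 path):

```python
# Sketch of the setup order after patch 10: optimizer first, then decorate,
# then the scaler. Model and sizes are placeholders, not the Imagen U-Net.
import paddle
import paddle.nn as nn

model = nn.Linear(256, 256)
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

# decorate after the optimizer has captured the parameters (patch 10 ordering)
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

x = paddle.randn([4, 256])
with paddle.amp.auto_cast(True, level='O2', dtype='bfloat16'):
    loss = model(x).mean()
scaler.scale(loss).backward()
scaler.step(opt)          # unscale, then apply the update
scaler.update()
opt.clear_grad()
```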