Skip to content

[bug] The tensor parallelism is slow and the memory of a single card has not been reduced. #403

@coder4nlp

Description

@coder4nlp

torchrun --nproc_per_node=2 test1110.py --height 1024 --width 1024 --steps 50 --attn flash --parallel-type tp
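单卡显存都在43G,运行时间42s (with 2 processes / tensor parallelism: each GPU still uses about 43G of memory, and the run takes 42s)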

使用单个进程时,显存56G,运行时间22s (with a single process: memory usage is 56G, and the run takes 22s)

import os
import sys
import argparse

import torch
import torch.distributed as dist

import cache_dit
from cache_dit import init_logger
from cache_dit.parallelism.parallel_backend import ParallelismBackend
sys.path.append("..")

import time

import torch
from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
    AutoencoderKLQwenImage,
)
def GiB():
    """Return the total memory of the current CUDA device in whole GiB.

    The value is truncated to an ``int``. ``0`` is returned when CUDA is not
    available or when querying the device raises.
    """
    bytes_per_gib = 1024**3
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            properties = torch.cuda.get_device_properties(device_index)
            return int(properties.total_memory / bytes_per_gib)
        return 0
    except Exception:
        # Best-effort probe: any failure is reported as "no GPU memory".
        return 0

def get_args(
    parse: bool = True,
) -> argparse.ArgumentParser | argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--fuse-lora", action="store_true", default=False)
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--Fn", type=int, default=8)
    parser.add_argument("--Bn", type=int, default=0)
    parser.add_argument("--rdt", type=float, default=0.08)
    parser.add_argument("--max-warmup-steps", "--w", type=int, default=8)
    parser.add_argument("--max-cached-steps", "--mc", type=int, default=-1)
    parser.add_argument(
        "--max-continuous-cached-steps", "--mcc", type=int, default=-1
    )
    parser.add_argument("--taylorseer", action="store_true", default=False)
    parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
    parser.add_argument("--height", type=int, default=None)
    parser.add_argument("--width", type=int, default=None)
    parser.add_argument("--quantize", "-q", action="store_true", default=False)
    # float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
    parser.add_argument(
        "--quantize-type",
        type=str,
        default="float8_weight_only",
        choices=[
            "float8",
            "float8_weight_only",
            "int8",
            "int8_weight_only",
            "int4",
            "int4_weight_only",
        ],
    )
    parser.add_argument(
        "--parallel-type",
        "--parallel",
        type=str,
        default=None,
        choices=[
            None,
            "tp",
            "ulysses",
            "ring",
        ],
    )
    parser.add_argument(
        "--attn",  # attention backend for context parallelism
        type=str,
        default=None,
        choices=[
            None,
            "flash",
            # Based on this fix: https://github.com/huggingface/diffusers/pull/12563
            "native",  # native pytorch attention: sdpa
            "_native_cudnn",
        ],
    )
    parser.add_argument("--perf", action="store_true", default=False)
    return parser.parse_args() if parse else parser

def cachify(
    args,
    pipe_or_adapter,
    **kwargs,
):
    """Enable cache_dit caching and/or parallelism on a pipeline or adapter.

    Nothing is done unless ``args.cache`` is set or ``args.parallel_type`` is
    not ``None``. Otherwise ``cache_dit.enable_cache`` is called with:

    * a cache config: the ``cache_config`` keyword if given, else a
      ``DBCacheConfig`` built from ``args`` when ``args.cache`` is set;
    * a ``TaylorSeerCalibratorConfig`` when ``args.taylorseer`` is set;
    * a parallelism config: the ``parallelism_config`` keyword if given, else
      a ``ParallelismConfig`` sized to the distributed world size when
      ``args.parallel_type`` is ``"ulysses"``, ``"ring"`` or ``"tp"``.

    Args:
        args: Parsed command-line namespace (see ``get_args``).
        pipe_or_adapter: Object handed to ``cache_dit.enable_cache``.
        **kwargs: Optional ``cache_config``, ``parallelism_config`` and
            ``enable_separate_cfg`` overrides.

    Returns:
        ``pipe_or_adapter``, the same object that was passed in.
    """
    wants_cache_or_parallel = args.cache or args.parallel_type is not None
    if not wants_cache_or_parallel:
        return pipe_or_adapter

    import torch.distributed as dist

    from cache_dit import (
        DBCacheConfig,
        ParallelismConfig,
        TaylorSeerCalibratorConfig,
    )

    cache_config = kwargs.pop("cache_config", None)
    parallelism_config = kwargs.pop("parallelism_config", None)

    # "tp" selects the native PyTorch backend; everything else uses the
    # diffusers one.
    if args.parallel_type in ["tp"]:
        backend = ParallelismBackend.NATIVE_PYTORCH
    else:
        backend = ParallelismBackend.NATIVE_DIFFUSER

    # Only the diffusers backend receives an attention backend name.
    parallel_kwargs = None
    if backend == ParallelismBackend.NATIVE_DIFFUSER:
        parallel_kwargs = {"attention_backend": args.attn or "_native_cudnn"}

    # A caller-supplied cache config wins; otherwise build one on --cache.
    if cache_config is None and args.cache:
        cache_config = DBCacheConfig(
            Fn_compute_blocks=args.Fn,
            Bn_compute_blocks=args.Bn,
            max_warmup_steps=args.max_warmup_steps,
            max_cached_steps=args.max_cached_steps,
            max_continuous_cached_steps=args.max_continuous_cached_steps,
            residual_diff_threshold=args.rdt,
            enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
        )

    calibrator_config = None
    if args.taylorseer:
        calibrator_config = TaylorSeerCalibratorConfig(
            taylorseer_order=args.taylorseer_order,
        )

    if parallelism_config is None and args.parallel_type in [
        "ulysses",
        "ring",
        "tp",
    ]:

        def world_size_if(kind):
            # The selected parallel type spans the whole world; others are off.
            return dist.get_world_size() if args.parallel_type == kind else None

        parallelism_config = ParallelismConfig(
            ulysses_size=world_size_if("ulysses"),
            ring_size=world_size_if("ring"),
            tp_size=world_size_if("tp"),
            backend=backend,
            parallel_kwargs=parallel_kwargs,
        )

    cache_dit.enable_cache(
        pipe_or_adapter,
        cache_config=cache_config,
        calibrator_config=calibrator_config,
        parallelism_config=parallelism_config,
    )
    return pipe_or_adapter

def strify(args, pipe_or_stats):
    """Build a short tag describing the run configuration.

    The tag has the form ``C<compile>_Q<quantize>[_<quantize_type>]_<stats>``
    where the quantize type is included only when ``args.quantize`` is set
    and ``<stats>`` is whatever ``cache_dit.strify(pipe_or_stats)`` returns.
    """
    type_suffix = ""
    if args.quantize and args.quantize_type != "":
        type_suffix = f"_{args.quantize_type}"
    compile_flag = int(args.compile)
    quantize_flag = int(args.quantize)
    stats_tag = cache_dit.strify(pipe_or_stats)
    return f"C{compile_flag}_Q{quantize_flag}{type_suffix}_{stats_tag}"

def maybe_init_distributed(args=None):
    """Initialise the NCCL process group when the run needs one.

    A process group is required when ``args`` is ``None`` (callers that
    always run distributed) or when ``args.parallel_type`` is not ``None``.
    In both cases the group is created only if it is not already initialised,
    so calling this function more than once is safe.

    Args:
        args: Parsed command-line namespace with a ``parallel_type``
            attribute, or ``None`` to always initialise.

    Returns:
        A ``(rank, device)`` tuple. For distributed runs ``device`` is the
        CUDA device ``rank % torch.cuda.device_count()``, which is also made
        the current CUDA device. For non-distributed runs the rank is ``0``
        and the device is ``"cuda"`` if available, else ``"cpu"``.
    """
    needs_process_group = args is None or args.parallel_type is not None
    if not needs_process_group:
        return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Guard against double initialisation, which would otherwise raise.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
        )
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    return rank, device

def maybe_destroy_distributed():
    """Destroy the process group if one is initialised; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
import cache_dit

# Parse the command line and echo the resulting namespace (every process
# prints it, since this runs before the rank is known).
args = get_args()
print(args)

# Creates the process group only when --parallel-type was given; otherwise
# this returns rank 0 and a default device.
rank, device = maybe_init_distributed(args)

# Load the pipeline in bfloat16. The model directory can be overridden with
# the QWEN_IMAGE_DIR environment variable.
pipe: QwenImagePipeline = QwenImagePipeline.from_pretrained(
    os.environ.get(
        "QWEN_IMAGE_DIR",
        "/models/Qwen/Qwen-Image"
    ),
    torch_dtype=torch.bfloat16,
)

assert isinstance(pipe.transformer, QwenImageTransformer2DModel)

# NOTE: quantization gating is disabled in this repro; the original line was:
# enable_quatization = args.quantize and GiB() < 96




# Apply cache and tensor parallelism here. Note this happens before the
# pipeline is moved to the target device below.
if args.cache or args.parallel_type is not None:
    cachify(args, pipe)



# Move the whole pipeline onto this process's device.
pipe.to(device)
# Quality suffixes keyed by prompt language; run_pipe appends the "en" entry.
positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}

# Generate image
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""

# A single space (not an empty string) is used when there is no specific
# concept to remove.
negative_prompt = " "

# Show the progress bar on rank 0 only.
pipe.set_progress_bar_config(disable=rank != 0)

def run_pipe(warmup: bool = False):
    """Run the module-level pipeline once and return the first image.

    Args:
        warmup: When true, only 5 inference steps are run regardless of
            ``--steps``; otherwise ``--steps`` is used (50 if not given).

    Returns:
        The first generated image, or ``None`` when ``--perf`` is set (the
        pipeline is then asked for latent output only).
    """
    # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    if warmup:
        num_steps = 5
    elif args.steps is None:
        num_steps = 50
    else:
        num_steps = args.steps

    call_kwargs = {
        "prompt": prompt + positive_magic["en"],
        "negative_prompt": negative_prompt,
        "width": args.width if args.width is not None else 1024,
        "height": args.height if args.height is not None else 1024,
        "num_inference_steps": num_steps,
        "true_cfg_scale": 4.0,
        # Fixed CPU seed so every call starts from the same noise.
        "generator": torch.Generator(device="cpu").manual_seed(0),
        "output_type": "latent" if args.perf else "pil",
    }
    output = pipe(**call_kwargs)

    if args.perf:
        return None
    return output.images[0]

# Optionally compile the transformer before any inference is run.
if args.compile:
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)

# warmup: a short run whose result is discarded, so one-off start-up cost is
# kept out of the timed run below.
_ = run_pipe(warmup=True)

# perf_counter is monotonic, so the measured interval is not affected by
# wall-clock adjustments.
start = time.perf_counter()
image = run_pipe()
end = time.perf_counter()

# Only rank 0 reports the summary and writes the image.
if rank == 0:
    cache_dit.summary(pipe)
    time_cost = end - start
    save_path = f"qwen-image.{strify(args, pipe)}.png"
    print(f"Time cost: {time_cost:.2f}s")
    if not args.perf:
        # In --perf mode run_pipe returns None, so there is nothing to save.
        print(f"Saving image to {save_path}")
        image.save(save_path)

maybe_destroy_distributed()

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions