torchrun --nproc_per_node=2 test1110.py --height 1024 --width 1024 --steps 50 --attn flash --parallel-type tp
With tensor parallelism across 2 GPUs, each GPU uses about 43 GB of memory and the run takes 42 s.
With a single process, memory usage is 56 GB and the run takes 22 s.
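For comparison, the single-process baseline was presumably launched along these lines (the exact flags are an assumption):

python test1110.py --height 1024 --width 1024 --steps 50

The full reproduction script: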
import os
import sys
import argparse
import time

import torch
import torch.distributed as dist

import cache_dit
from cache_dit.parallelism.parallel_backend import ParallelismBackend

sys.path.append("..")

from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
)
def GiB():
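    """Total memory of the current CUDA device in GiB (0 if CUDA is unavailable)."""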
try:
if not torch.cuda.is_available():
return 0
total_memory_bytes = torch.cuda.get_device_properties(
torch.cuda.current_device(),
).total_memory
total_memory_gib = total_memory_bytes / (1024**3)
return int(total_memory_gib)
except Exception:
return 0
def get_args(
parse: bool = True,
) -> argparse.ArgumentParser | argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--cache", action="store_true", default=False)
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--fuse-lora", action="store_true", default=False)
parser.add_argument("--steps", type=int, default=None)
parser.add_argument("--Fn", type=int, default=8)
parser.add_argument("--Bn", type=int, default=0)
parser.add_argument("--rdt", type=float, default=0.08)
parser.add_argument("--max-warmup-steps", "--w", type=int, default=8)
parser.add_argument("--max-cached-steps", "--mc", type=int, default=-1)
parser.add_argument(
"--max-continuous-cached-steps", "--mcc", type=int, default=-1
)
parser.add_argument("--taylorseer", action="store_true", default=False)
parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
parser.add_argument("--height", type=int, default=None)
parser.add_argument("--width", type=int, default=None)
parser.add_argument("--quantize", "-q", action="store_true", default=False)
# float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
parser.add_argument(
"--quantize-type",
type=str,
default="float8_weight_only",
choices=[
"float8",
"float8_weight_only",
"int8",
"int8_weight_only",
"int4",
"int4_weight_only",
],
)
parser.add_argument(
"--parallel-type",
"--parallel",
type=str,
default=None,
choices=[
None,
"tp",
"ulysses",
"ring",
],
)
parser.add_argument(
"--attn", # attention backend for context parallelism
type=str,
default=None,
choices=[
None,
"flash",
# Based on this fix: https://github.com/huggingface/diffusers/pull/12563
"native", # native pytorch attention: sdpa
"_native_cudnn",
],
)
parser.add_argument("--perf", action="store_true", default=False)
return parser.parse_args() if parse else parser
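# Example invocations (the tp run is the command reported above; the ulysses
# variant is an assumed analogue for context parallelism):
#   torchrun --nproc_per_node=2 test1110.py --height 1024 --width 1024 \
#       --steps 50 --attn flash --parallel-type tp
#   torchrun --nproc_per_node=2 test1110.py --height 1024 --width 1024 \
#       --steps 50 --attn flash --parallel-type ulysses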
def cachify(
args,
pipe_or_adapter,
**kwargs,
):
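    """Enable DBCache and/or parallelism on the pipeline via cache_dit.enable_cache."""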
    if args.cache or args.parallel_type is not None:
        from cache_dit import (
            DBCacheConfig,
            ParallelismConfig,
            TaylorSeerCalibratorConfig,
        )
cache_config = kwargs.pop("cache_config", None)
parallelism_config = kwargs.pop("parallelism_config", None)
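        # Backend choice: tp runs through the native PyTorch tensor-parallel
        # path, while ulysses/ring context parallelism goes through the
        # diffusers-native backend and accepts an attention_backend kwarg.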
backend = (
ParallelismBackend.NATIVE_PYTORCH
if args.parallel_type in ["tp"]
else ParallelismBackend.NATIVE_DIFFUSER
)
parallel_kwargs = (
{
"attention_backend": (
"_native_cudnn" if not args.attn else args.attn
)
}
if backend == ParallelismBackend.NATIVE_DIFFUSER
else None
)
cache_dit.enable_cache(
pipe_or_adapter,
cache_config=(
DBCacheConfig(
Fn_compute_blocks=args.Fn,
Bn_compute_blocks=args.Bn,
max_warmup_steps=args.max_warmup_steps,
max_cached_steps=args.max_cached_steps,
max_continuous_cached_steps=args.max_continuous_cached_steps,
residual_diff_threshold=args.rdt,
enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
)
if cache_config is None and args.cache
else cache_config
),
calibrator_config=(
TaylorSeerCalibratorConfig(
taylorseer_order=args.taylorseer_order,
)
if args.taylorseer
else None
),
parallelism_config=(
ParallelismConfig(
ulysses_size=(
dist.get_world_size()
if args.parallel_type == "ulysses"
else None
),
ring_size=(
dist.get_world_size()
if args.parallel_type == "ring"
else None
),
tp_size=(
dist.get_world_size()
if args.parallel_type == "tp"
else None
),
backend=backend,
parallel_kwargs=parallel_kwargs,
)
if parallelism_config is None
and args.parallel_type in ["ulysses", "ring", "tp"]
else parallelism_config
),
)
return pipe_or_adapter
def strify(args, pipe_or_stats):
quantize_type = args.quantize_type if args.quantize else ""
if quantize_type != "":
quantize_type = f"_{quantize_type}"
return (
f"C{int(args.compile)}_Q{int(args.quantize)}{quantize_type}_"
f"{cache_dit.strify(pipe_or_stats)}"
)
def maybe_init_distributed(args=None):
    # Initialize NCCL and bind this rank to a GPU when a parallel mode is
    # requested (or whenever called without args, as other examples do).
    if args is None or args.parallel_type is not None:
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        device = torch.device("cuda", rank % torch.cuda.device_count())
        torch.cuda.set_device(device)
        return rank, device
    return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")
def maybe_destroy_distributed():
if dist.is_initialized():
dist.destroy_process_group()
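# Main script: load the pipeline, apply caching/parallelism, then run and time it.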
args = get_args()
print(args)
rank, device = maybe_init_distributed(args)
pipe: QwenImagePipeline = QwenImagePipeline.from_pretrained(
os.environ.get(
"QWEN_IMAGE_DIR",
"/models/Qwen/Qwen-Image"
),
torch_dtype=torch.bfloat16,
)
assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
# enable_quantization = args.quantize and GiB() < 96
# Apply cache and tensor parallelism here
if args.cache or args.parallel_type is not None:
cachify(args, pipe)
pipe.to(device)
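# Note: pipe.to(device) puts the full text encoder and VAE on every rank;
# only the transformer is parallelized, so per-rank memory will not simply
# halve under 2-way TP.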
positive_magic = {
"en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
"zh": ", 超清,4K,电影级构图.", # for chinese prompt
}
# Generate image
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
# Use an empty string if you do not have a specific concept to remove.
negative_prompt = " "
pipe.set_progress_bar_config(disable=rank != 0)
def run_pipe(warmup: bool = False):
# do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
output = pipe(
prompt=prompt + positive_magic["en"],
negative_prompt=negative_prompt,
width=1024 if args.width is None else args.width,
height=1024 if args.height is None else args.height,
num_inference_steps=(
(50 if args.steps is None else args.steps) if not warmup else 5
),
true_cfg_scale=4.0,
generator=torch.Generator(device="cpu").manual_seed(0),
output_type="latent" if args.perf else "pil",
)
image = output.images[0] if not args.perf else None
return image
if args.compile:
cache_dit.set_compile_configs()
pipe.transformer = torch.compile(pipe.transformer)
# Warmup pass (also triggers torch.compile tracing when --compile is set).
_ = run_pipe(warmup=True)
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()
image = run_pipe()
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.time()
if rank == 0:
cache_dit.summary(pipe)
time_cost = end - start
save_path = f"qwen-image.{strify(args, pipe)}.png"
print(f"Time cost: {time_cost:.2f}s")
if not args.perf:
print(f"Saving image to {save_path}")
image.save(save_path)
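# Hypothetical addition: report each rank's peak allocator memory, to compare
# against nvidia-smi (which shows reserved rather than allocated memory).
if torch.cuda.is_available():
    peak_gib = torch.cuda.max_memory_allocated(device) / (1024**3)
    print(f"[rank {rank}] peak allocated memory: {peak_gib:.1f} GiB")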
maybe_destroy_distributed()