Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions OmniGen/model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
# The code is revised from DiT
import os
import gc
import warnings
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import math
from typing import Dict

from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import BitsAndBytesConfig

from OmniGen.transformer import Phi3Config, Phi3Transformer
from OmniGen.utils import quantize_bnb


logger = logging.get_logger(__name__)

def modulate(x, shift, scale):
    """Apply a scale-then-shift modulation to ``x``.

    Both ``shift`` and ``scale`` receive a new axis at position 1 so that
    they broadcast against ``x`` along that axis. The result is
    ``x * (1 + scale) + shift``, so a zero ``scale`` together with a zero
    ``shift`` leaves ``x`` unchanged.

    NOTE(review): presumably ``x`` is (batch, tokens, dim) and ``shift`` /
    ``scale`` are (batch, dim) -- confirm against the callers.

    Args:
        x: Tensor to modulate.
        shift: Additive term, expanded at dim 1 before use.
        scale: Multiplicative term (offset by one), expanded at dim 1.

    Returns:
        The modulated tensor.
    """
    # Insert the broadcast axis once, then combine in the same order as before.
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset

Expand Down Expand Up @@ -162,6 +171,7 @@ def __init__(
pos_embed_max_size: int = 192,
):
super().__init__()
self.config = transformer_config
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
Expand All @@ -185,22 +195,78 @@ def __init__(

self.llm = Phi3Transformer(config=transformer_config)
self.llm.config.use_cache = False

# bnb quantized models cannot easily be offloaded or recast
self.quantized = False
self.dtype = None

@classmethod
def from_pretrained(cls, model_name):
if not os.path.exists(model_name):
def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = None, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,):
model_path = Path(model_name)
config_loc = model_name # these only diverge when model_name is *.safetensors or *.pt file

if model_path.exists():
if model_path.is_dir():
if (weights_loc := list(model_path.glob('*.safetensors'))):
model_path = weights_loc[0]
elif (weights_loc := list(model_path.glob('*.pt'))):
model_path = weights_loc[0]
else:
raise FileNotFoundError(f'No .safetensors or .pt model weights found in {model_path.as_posix()!r}')
else:
logger.info("Loading model weights from file. Using default config from 'Shitao/OmniGen-v1'.")
config_loc = "Shitao/OmniGen-v1"
else:
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
config = Phi3Config.from_pretrained(model_name)
model = cls(config)
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
print("Loading safetensors")
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
model_path = snapshot_download(repo_id=model_name, cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])

# assume hub files are always .safetensors
model_path = next(Path(model_path).glob('*.safetensors'))

ckpt = (load_file(model_path, 'cpu') if model_path.suffix == '.safetensors' else
torch.load(model_path, map_location='cpu'))

config = Phi3Config.from_pretrained(config_loc)
# avoid inadvertently leaving the weights as float32
if dtype is None:
dtype = config.torch_dtype

if hasattr(config, 'quantization_config'):
if quantization_config is not None:
# from: diffusers.quantizers.auto
warnings.warn(
"You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
" already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
)

config.quantization_config.pop("quant_method",None) # prevent unused keys warning
quantization_config = BitsAndBytesConfig.from_dict(config.quantization_config)

if low_cpu_mem_usage:
with init_empty_weights():
model = cls(config)

if quantization_config:
model = quantize_bnb(model, ckpt, quantization_config=quantization_config, dtype=dtype)
model.quantized = True
model.config.quantization_config = quantization_config
else:
model.load_state_dict(ckpt, assign=True)
else:
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
model.load_state_dict(ckpt)
if quantization_config:
raise ValueError('Quantization not supported for `low_cpu_mem_usage=False`.')

model = cls(config)
model.load_state_dict(ckpt)


# determine dtype via x_emb bias since as a Conv2d bias, it should never be quantized
model.dtype = model.x_embedder.proj.bias.dtype

del ckpt
torch.cuda.empty_cache()
gc.collect()
return model

def initialize_weights(self):
Expand Down
121 changes: 81 additions & 40 deletions OmniGen/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, Literal
import gc

from PIL import Image
Expand All @@ -17,6 +18,7 @@
scale_lora_layers,
unscale_lora_layers,
)
from transformers import BitsAndBytesConfig
from safetensors.torch import load_file

from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
Expand All @@ -41,6 +43,15 @@
```
"""

def best_available_device():
    """Pick the torch device to use when the caller did not specify one.

    The backends are probed in order of preference: CUDA first, then MPS.
    If neither reports itself as available, the CPU device is returned and
    the fallback is reported through the module logger.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")

    # No accelerator backend is available: tell the user before falling back.
    logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
    return torch.device("cpu")

class OmniGenPipeline:
def __init__(
Expand All @@ -55,14 +66,10 @@ def __init__(
self.processor = processor
self.device = device

if device is None:
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
self.device = torch.device("cpu")
if self.device is None:
self.device = best_available_device()
elif isinstance(self.device, str):
self.device = torch.device(self.device)

# self.model.to(torch.bfloat16)
self.model.eval()
Expand All @@ -71,28 +78,46 @@ def __init__(
self.model_cpu_offload = False

@classmethod
def from_pretrained(cls, model_name, vae_path: str=None):
if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
# logger.info("Model not found, downloading...")
print("Model not found, downloading...")
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
# logger.info(f"Downloaded model to {model_name}")
print(f"Downloaded model to {model_name}")
model = OmniGen.from_pretrained(model_name)
processor = OmniGenProcessor.from_pretrained(model_name)

if os.path.exists(os.path.join(model_name, "vae")):
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
elif vae_path is not None:
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
else:
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True, **kwargs):
pretrained_path = Path(model_name)

# XXX: Consider renaming 'model' to 'transformer' conform to diffusers pipeline syntax
model = kwargs.get('model', None)
processor = kwargs.get('processor', None)
vae = kwargs.get('vae', None)

return cls(vae, model, processor)
# NOTE: should technically allow delayed component inits via model/vae = None, but seems like more of a footgun than it's worth at this point

if not pretrained_path.exists():
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt']

if model is not None:
ignore_patterns.append('model.safetensors') # avoid downloading bf16 model if passing existing model

logger.info("Model not found, downloading...")
cache_folder = os.getenv('HF_HUB_CACHE')
pretrained_path = Path(snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=ignore_patterns))
logger.info(f"Downloaded model to {pretrained_path}")

if model is None:
model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=None, low_cpu_mem_usage=low_cpu_mem_usage)

model = model.requires_grad_(False).eval()

if processor is None:
processor = OmniGenProcessor.from_pretrained(model_name)

if vae is None:
if vae_path is None:
vae_path = pretrained_path.joinpath("vae")

if not os.path.exists(vae_path):
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
vae_path = "stabilityai/sdxl-vae"

vae = AutoencoderKL.from_pretrained(vae_path, low_cpu_mem_usage=low_cpu_mem_usage)

return cls(vae, model, processor, device)

def merge_lora(self, lora_path: str):
model = PeftModel.from_pretrained(self.model, lora_path)
Expand Down Expand Up @@ -123,7 +148,8 @@ def move_to_device(self, data):

def enable_model_cpu_offload(self):
self.model_cpu_offload = True
self.model.to("cpu")
if not self.model.quantized:
self.model.to("cpu")
self.vae.to("cpu")
torch.cuda.empty_cache() # Clear VRAM
gc.collect() # Run garbage collection to free system RAM
Expand Down Expand Up @@ -212,7 +238,13 @@ def __call__(
# set model and processor
if max_input_image_size != self.processor.max_image_size:
self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
self.model.to(dtype)

if not self.model.quantized:
self.model.dtype = dtype
self.model.to(dtype)

#self.vae.to(dtype) # Uncomment this line to allow bfloat16 VAE

if offload_model:
self.enable_model_cpu_offload()
else:
Expand All @@ -234,20 +266,20 @@ def __call__(
else:
generator = None
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
latents = torch.cat([latents]*(1+num_cfg), 0).to(self.model.dtype)

if input_images is not None and self.model_cpu_offload: self.vae.to(self.device)
input_img_latents = []
if separate_cfg_infer:
for temp_pixel_values in input_data['input_pixel_values']:
temp_input_latents = []
for img in temp_pixel_values:
img = self.vae_encode(img.to(self.device), dtype)
img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype)
temp_input_latents.append(img)
input_img_latents.append(temp_input_latents)
else:
for img in input_data['input_pixel_values']:
img = self.vae_encode(img.to(self.device), dtype)
img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype)
input_img_latents.append(img)
if input_images is not None and self.model_cpu_offload:
self.vae.to('cpu')
Expand All @@ -263,36 +295,45 @@ def __call__(
img_cfg_scale=img_guidance_scale,
use_img_cfg=use_img_guidance,
use_kv_cache=use_kv_cache,
offload_model=offload_model,
offload_model=(offload_model and not self.model.quantized),
)

if separate_cfg_infer:
func = self.model.forward_with_separate_cfg
else:
func = self.model.forward_with_cfg

if self.model_cpu_offload:
if self.model_cpu_offload and not self.model.quantized:
for name, param in self.model.named_parameters():
if 'layers' in name and 'layers.0' not in name:
param.data = param.data.cpu()
param.data = param.data.to('cpu')
else:
param.data = param.data.to(self.device)
for buffer_name, buffer in self.model.named_buffers():
setattr(self.model, buffer_name, buffer.to(self.device))
torch.cuda.empty_cache()
gc.collect()
# else:
# self.model.to(self.device)

scheduler = OmniGenScheduler(num_steps=num_inference_steps)
if latents.dtype == torch.float16:
# Continue to monitor. If _clip_val never changes, can remove scheduler autoset func and just hardcode clip val here.
#self.model.llm.set_clip_val(2**16-32 - 2*32) # hardcode clip val
# dry run the inputs, adjusting the clip bounds as necessary
scheduler._fp16_clip_autoset(self.model.llm, latents, func, model_kwargs)
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
samples = samples.chunk((1+num_cfg), dim=0)[0]

if self.model_cpu_offload:
self.model.to('cpu')
if not self.model.quantized:
self.model.to("cpu")

torch.cuda.empty_cache()
gc.collect()

self.vae.to(self.device)
samples = samples.to(torch.float32)
samples = samples.to(self.vae.dtype)
if self.vae.config.shift_factor is not None:
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
else:
Expand Down
Loading