11import os
22import inspect
3- from typing import Any , Callable , Dict , List , Optional , Union
3+ from typing import Any , Callable , Dict , List , Optional , Union , Literal
44import gc
55
66from PIL import Image
1717 scale_lora_layers ,
1818 unscale_lora_layers ,
1919)
20+ from transformers import BitsAndBytesConfig
2021from safetensors .torch import load_file
2122
2223from OmniGen import OmniGen , OmniGenProcessor , OmniGenScheduler
@@ -76,7 +77,7 @@ def __init__(
7677 self .model_cpu_offload = False
7778
7879 @classmethod
79- def from_pretrained (cls , model_name , vae_path : str = None , device = None , low_cpu_mem_usage = True ):
80+ def from_pretrained (cls , model_name , vae_path : str = None , device = None , quantization_config : Literal [ 'bnb_4bit' , 'bnb_8bit' ] | BitsAndBytesConfig = None , low_cpu_mem_usage = True ):
8081 if not os .path .exists (model_name ) or (not os .path .exists (os .path .join (model_name , 'model.safetensors' )) and model_name == "Shitao/OmniGen-v1" ):
8182 logger .info ("Model not found, downloading..." )
8283 cache_folder = os .getenv ('HF_HUB_CACHE' )
@@ -87,8 +88,16 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me
8788
8889 if device is None :
8990 device = best_available_device ()
90-
91- model = OmniGen .from_pretrained (model_name , dtype = torch .bfloat16 , low_cpu_mem_usage = low_cpu_mem_usage )
91+
92+ if isinstance (quantization_config , str ):
93+ if quantization_config == 'bnb_4bit' :
94+ quantization_config = BitsAndBytesConfig (load_in_4bit = True , bnb_4bit_compute_dtype = torch .float32 , bnb_4bit_quant_type = 'nf4' , bnb_4bit_use_double_quant = False )
95+ elif quantization_config == 'bnb_8bit' :
96+ quantization_config = BitsAndBytesConfig (load_in_8bit = True )
97+ else :
98+ raise NotImplementedError (f'Unknown `quantization_config` { quantization_config !r} ' )
99+
100+ model = OmniGen .from_pretrained (model_name , dtype = torch .bfloat16 , quantization_config = quantization_config , low_cpu_mem_usage = low_cpu_mem_usage )
92101 processor = OmniGenProcessor .from_pretrained (model_name )
93102
94103 if vae_path is None :
@@ -98,7 +107,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me
98107 logger .info (f"No VAE found in { model_name } , downloading stabilityai/sdxl-vae from HF" )
99108 vae_path = "stabilityai/sdxl-vae"
100109
101- vae = AutoencoderKL .from_pretrained (vae_path ). to ( device )
110+ vae = AutoencoderKL .from_pretrained (vae_path )
102111
103112 return cls (vae , model , processor , device )
104113
@@ -131,7 +140,8 @@ def move_to_device(self, data):
131140
132141 def enable_model_cpu_offload (self ):
133142 self .model_cpu_offload = True
134- self .model .to ("cpu" )
143+ if self .model .offloadable :
144+ self .model .to ("cpu" )
135145 self .vae .to ("cpu" )
136146 torch .cuda .empty_cache () # Clear VRAM
137147 gc .collect () # Run garbage collection to free system RAM
@@ -221,6 +231,7 @@ def __call__(
221231 if max_input_image_size != self .processor .max_image_size :
222232 self .processor = OmniGenProcessor (self .processor .text_tokenizer , max_image_size = max_input_image_size )
223233 self .model .to (dtype )
234+ #self.vae.to(dtype) # Uncomment this line to allow bfloat16 VAE
224235 if offload_model :
225236 self .enable_model_cpu_offload ()
226237 else :
@@ -250,12 +261,12 @@ def __call__(
250261 for temp_pixel_values in input_data ['input_pixel_values' ]:
251262 temp_input_latents = []
252263 for img in temp_pixel_values :
253- img = self .vae_encode (img .to (self .device ), dtype )
264+ img = self .vae_encode (img .to (self .vae . device , self . vae . dtype ), dtype )
254265 temp_input_latents .append (img )
255266 input_img_latents .append (temp_input_latents )
256267 else :
257268 for img in input_data ['input_pixel_values' ]:
258- img = self .vae_encode (img .to (self .device ), dtype )
269+ img = self .vae_encode (img .to (self .vae . device , self . vae . dtype ), dtype )
259270 input_img_latents .append (img )
260271 if input_images is not None and self .model_cpu_offload :
261272 self .vae .to ('cpu' )
@@ -279,7 +290,7 @@ def __call__(
279290 else :
280291 func = self .model .forward_with_cfg
281292
282- if self .model_cpu_offload :
293+ if self .model_cpu_offload and self . model . offloadable :
283294 for name , param in self .model .named_parameters ():
284295 if 'layers' in name and 'layers.0' not in name :
285296 param .data = param .data .cpu ()
@@ -294,13 +305,13 @@ def __call__(
294305 samples = scheduler (latents , func , model_kwargs , use_kv_cache = use_kv_cache , offload_kv_cache = offload_kv_cache )
295306 samples = samples .chunk ((1 + num_cfg ), dim = 0 )[0 ]
296307
297- if self .model_cpu_offload :
308+ if self .model_cpu_offload and self . model . offloadable :
298309 self .model .to ('cpu' )
299310 torch .cuda .empty_cache ()
300311 gc .collect ()
301312
302313 self .vae .to (self .device )
303- samples = samples .to (torch . float32 )
314+ samples = samples .to (self . vae . dtype )
304315 if self .vae .config .shift_factor is not None :
305316 samples = samples / self .vae .config .scaling_factor + self .vae .config .shift_factor
306317 else :