
Commit 99012bd

feat: add 4bit and 8bit quantization support with bitsandbytes
Add a quantization utility based on HfQuantizers. Modify the pipeline to accept a quantization_config. Lays the groundwork for allowing a bf16 VAE. Update requirements to include bitsandbytes. Closes #45, closes #64
1 parent ec1a9a2 commit 99012bd
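
For orientation, a minimal usage sketch of the new option. The 'bnb_4bit' shorthand and the from_pretrained signature come from the pipeline.py diff below; the pipeline class name and the generation arguments are assumptions based on the existing API:

    from OmniGen import OmniGenPipeline

    # 'bnb_4bit' is expanded to an NF4 BitsAndBytesConfig inside from_pretrained
    pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", quantization_config='bnb_4bit')
    images = pipe(prompt="a photo of a cat", height=512, width=512)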

File tree: 4 files changed (+69 −18 lines)


OmniGen/model.py

Lines changed: 20 additions & 7 deletions
@@ -12,9 +12,10 @@
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
 from accelerate import init_empty_weights
+from transformers import BitsAndBytesConfig
 
 from OmniGen.transformer import Phi3Config, Phi3Transformer
-
+from OmniGen.utils import quantize_bnb
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -187,9 +188,13 @@ def __init__(
 
         self.llm = Phi3Transformer(config=transformer_config)
         self.llm.config.use_cache = False
+
+        # bnb 4bit quantized models cannot be offloaded
+        self.offloadable = True
+        self.quantization_config = None
 
     @classmethod
-    def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True,):
+    def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,):
         if not os.path.exists(model_name):
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
@@ -201,22 +206,30 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch
             model_path = os.path.join(model_name, 'model.pt')
             ckpt = torch.load(model_path, map_location='cpu')
         else:
-            print("Loading safetensors")
+            #print("Loading safetensors")
             ckpt = load_file(model_path, 'cpu')
 
         if low_cpu_mem_usage:
             with init_empty_weights():
                 config = Phi3Config.from_pretrained(model_name)
                 model = cls(config)
-
-            model.load_state_dict(ckpt, assign=True)
-            model = model.to(dtype)
+
+            if quantization_config:
+                model = quantize_bnb(model, ckpt, quantization_config=quantization_config, pre_quantized=False)
+                if getattr(quantization_config, 'load_in_4bit', None):
+                    model.offloadable = False
+                model.quantization_config = quantization_config
+            else:
+                model.load_state_dict(ckpt, assign=True)
         else:
+            if quantization_config:
+                raise ValueError('Quantization not supported for `low_cpu_mem_usage=False`.')
+
             config = Phi3Config.from_pretrained(model_name)
             model = cls(config)
             model.load_state_dict(ckpt)
-            model = model.to(dtype)
 
+        model = model.to(dtype)
         del ckpt
         torch.cuda.empty_cache()
        gc.collect()
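
Instead of a string preset, callers can pass an explicit BitsAndBytesConfig straight to the model class. A minimal sketch, mirroring the defaults the pipeline's 'bnb_4bit' preset uses:

    import torch
    from transformers import BitsAndBytesConfig
    from OmniGen import OmniGen

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float32,  # same compute dtype as the 'bnb_4bit' preset
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=False,
    )
    model = OmniGen.from_pretrained("Shitao/OmniGen-v1", dtype=torch.bfloat16,
                                    quantization_config=bnb_config)
    assert not model.offloadable  # from_pretrained clears the flag for 4-bit weights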

OmniGen/pipeline.py

Lines changed: 22 additions & 11 deletions
@@ -1,6 +1,6 @@
 import os
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Literal
 import gc
 
 from PIL import Image
@@ -17,6 +17,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
+from transformers import BitsAndBytesConfig
 from safetensors.torch import load_file
 
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
@@ -76,7 +77,7 @@ def __init__(
         self.model_cpu_offload = False
 
     @classmethod
-    def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True):
+    def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantization_config:Literal['bnb_4bit','bnb_8bit']|BitsAndBytesConfig=None, low_cpu_mem_usage=True):
         if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
@@ -87,8 +88,16 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me
 
         if device is None:
             device = best_available_device()
-
-        model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, low_cpu_mem_usage=low_cpu_mem_usage)
+
+        if isinstance(quantization_config, str):
+            if quantization_config == 'bnb_4bit':
+                quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=False)
+            elif quantization_config == 'bnb_8bit':
+                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+            else:
+                raise NotImplementedError(f'Unknown `quantization_config` {quantization_config!r}')
+
+        model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage)
         processor = OmniGenProcessor.from_pretrained(model_name)
 
         if vae_path is None:
@@ -98,7 +107,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me
             logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
             vae_path = "stabilityai/sdxl-vae"
 
-        vae = AutoencoderKL.from_pretrained(vae_path).to(device)
+        vae = AutoencoderKL.from_pretrained(vae_path)
 
         return cls(vae, model, processor, device)
 
@@ -131,7 +140,8 @@ def move_to_device(self, data):
 
     def enable_model_cpu_offload(self):
         self.model_cpu_offload = True
-        self.model.to("cpu")
+        if self.model.offloadable:
+            self.model.to("cpu")
         self.vae.to("cpu")
         torch.cuda.empty_cache() # Clear VRAM
         gc.collect() # Run garbage collection to free system RAM
@@ -221,6 +231,7 @@ def __call__(
         if max_input_image_size != self.processor.max_image_size:
             self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
         self.model.to(dtype)
+        #self.vae.to(dtype) # Uncomment this line to allow bfloat16 VAE
         if offload_model:
             self.enable_model_cpu_offload()
         else:
@@ -250,12 +261,12 @@ def __call__(
             for temp_pixel_values in input_data['input_pixel_values']:
                 temp_input_latents = []
                 for img in temp_pixel_values:
-                    img = self.vae_encode(img.to(self.device), dtype)
+                    img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype)
                     temp_input_latents.append(img)
                 input_img_latents.append(temp_input_latents)
         else:
             for img in input_data['input_pixel_values']:
-                img = self.vae_encode(img.to(self.device), dtype)
+                img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype)
                 input_img_latents.append(img)
         if input_images is not None and self.model_cpu_offload:
             self.vae.to('cpu')
@@ -279,7 +290,7 @@ def __call__(
         else:
             func = self.model.forward_with_cfg
 
-        if self.model_cpu_offload:
+        if self.model_cpu_offload and self.model.offloadable:
             for name, param in self.model.named_parameters():
                 if 'layers' in name and 'layers.0' not in name:
                     param.data = param.data.cpu()
@@ -294,13 +305,13 @@ def __call__(
         samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
         samples = samples.chunk((1+num_cfg), dim=0)[0]
 
-        if self.model_cpu_offload:
+        if self.model_cpu_offload and self.model.offloadable:
            self.model.to('cpu')
            torch.cuda.empty_cache()
            gc.collect()
 
        self.vae.to(self.device)
-        samples = samples.to(torch.float32)
+        samples = samples.to(self.vae.dtype)
        if self.vae.config.shift_factor is not None:
            samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
        else:
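
Note the offload interaction introduced above: 4-bit bitsandbytes weights cannot move between devices, so enable_model_cpu_offload and the per-layer offload path now consult model.offloadable and leave the transformer in place, moving only the VAE. A short sketch (8-bit models keep offloadable=True and behave as before):

    pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", quantization_config='bnb_4bit')
    pipe.enable_model_cpu_offload()  # offloadable is False for 4-bit, so only the VAE moves to CPU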

OmniGen/utils.py

Lines changed: 26 additions & 0 deletions
@@ -1,9 +1,14 @@
+import gc
 import logging
 
 from PIL import Image
 import torch
 import numpy as np
 
+from transformers import BitsAndBytesConfig
+from transformers.quantizers import AutoHfQuantizer
+from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device
+
 def create_logger(logging_dir):
     """
     Create a logger that writes to a log file and stdout.
@@ -108,3 +113,24 @@ def vae_encode_list(vae, x, weight_dtype):
         latents.append(img)
     return latents
 
+
+
+@torch.no_grad()
+def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=False):
+    # from transformers.integrations import get_keys_to_not_convert
+
+    quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=pre_quantized)
+    no_convert = [] #get_keys_to_not_convert(meta_model.llm) # might be worth investigating
+
+    model = replace_with_bnb_linear(meta_model, modules_to_not_convert=no_convert, quantization_config=quantizer.quantization_config)
+
+    for param_name, param in state_dict.items():
+        if not quantizer.check_quantized_param(model, param, param_name, state_dict):
+            set_module_quantized_tensor_to_device(model, param_name, device=0, value=param)
+        else:
+            quantizer.create_quantized_param(model, param, param_name, target_device=0, state_dict=state_dict)
+
+    del state_dict
+    torch.cuda.empty_cache()
+    gc.collect()
+    return model
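
quantize_bnb materializes a meta-device model directly into quantized storage: replace_with_bnb_linear swaps nn.Linear modules for bitsandbytes equivalents, then each checkpoint tensor is streamed onto GPU 0, quantized when the quantizer claims it. A hedged sketch of standalone use, mirroring the call site in OmniGen.from_pretrained (the snapshot path is illustrative):

    from accelerate import init_empty_weights
    from safetensors.torch import load_file
    from transformers import BitsAndBytesConfig

    from OmniGen import OmniGen
    from OmniGen.transformer import Phi3Config
    from OmniGen.utils import quantize_bnb

    model_dir = "/path/to/OmniGen-v1"  # illustrative local snapshot

    with init_empty_weights():  # build the module on the meta device; no weight memory allocated
        model = OmniGen(Phi3Config.from_pretrained(model_dir))

    ckpt = load_file(f"{model_dir}/model.safetensors", "cpu")
    model = quantize_bnb(model, ckpt, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
    # quantize_bnb hardcodes device 0, so the quantized weights land on cuda:0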

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ pillow==10.2.0
 peft==0.13.2
 diffusers==0.30.3
 timm==0.9.16
+bitsandbytes==0.44.1
