
Commit 4f136f8

DoctorKey and yiyixuxu authored
Add support for Ovis-Image (#12740)
* add ovis_image
* fix code quality
* optimize pipeline_ovis_image.py according to the feedbacks
* optimize imports
* add docs
* make style
* make style
* add ovis to toctree
* oops

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent edf36f5 commit 4f136f8

15 files changed: 1,714 additions and 0 deletions


docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -375,6 +375,8 @@
     title: MochiTransformer3DModel
   - local: api/models/omnigen_transformer
     title: OmniGenTransformer2DModel
+  - local: api/models/ovisimage_transformer2d
+    title: OvisImageTransformer2DModel
   - local: api/models/pixart_transformer2d
     title: PixArtTransformer2DModel
   - local: api/models/prior_transformer
@@ -567,6 +569,8 @@
     title: MultiDiffusion
   - local: api/pipelines/omnigen
     title: OmniGen
+  - local: api/pipelines/ovis_image
+    title: Ovis-Image
   - local: api/pipelines/pag
     title: PAG
   - local: api/pipelines/paint_by_example
```
docs/source/en/api/models/ovisimage_transformer2d.md

Lines changed: 24 additions & 0 deletions

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# OvisImageTransformer2DModel

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import OvisImageTransformer2DModel

transformer = OvisImageTransformer2DModel.from_pretrained("AIDC-AI/Ovis-Image-7B", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## OvisImageTransformer2DModel

[[autodoc]] OvisImageTransformer2DModel
docs/source/en/api/pipelines/ovis_image.md

Lines changed: 50 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Ovis-Image

![concepts](https://github.com/AIDC-AI/Ovis-Image/blob/main/docs/imgs/ovis_image_case.png)

Ovis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering and designed to operate efficiently under stringent computational constraints.

[Ovis-Image Technical Report](https://arxiv.org/abs/2511.22982) from Alibaba Group, by Guo-Hua Wang, Liangfu Cao, Tianyu Cui, Minghao Fu, Xiaohao Chen, Pengxin Zhan, Jianshan Zhao, Lan Li, Bowen Fu, Jiaqi Liu, Qing-Guo Chen.

The abstract from the paper is:

*We introduce Ovis-Image, a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints. Built upon our previous Ovis-U1 framework, Ovis-Image integrates a diffusion-based visual decoder with the stronger Ovis 2.5 multimodal backbone, leveraging a text-centric training pipeline that combines large-scale pre-training with carefully tailored post-training refinements. Despite its compact architecture, Ovis-Image achieves text rendering performance on par with significantly larger open models such as Qwen-Image and approaches closed-source systems like Seedream and GPT4o. Crucially, the model remains deployable on a single high-end GPU with moderate memory, narrowing the gap between frontier-level text rendering and practical deployment. Our results indicate that combining a strong multimodal backbone with a carefully designed, text-focused training recipe is sufficient to achieve reliable bilingual text rendering without resorting to oversized or proprietary models.*

**Highlights**:

* **Strong text rendering at a compact 7B scale**: Ovis-Image is a 7B text-to-image model that delivers text rendering quality comparable to much larger 20B-class systems such as Qwen-Image and competitive with leading closed-source models like GPT4o in text-centric scenarios, while remaining small enough to run on widely accessible hardware.
* **High fidelity on text-heavy, layout-sensitive prompts**: The model excels on prompts that demand tight alignment between linguistic content and rendered typography (e.g., posters, banners, logos, UI mockups, infographics), producing legible, correctly spelled, and semantically consistent text across diverse fonts, sizes, and aspect ratios without compromising overall visual quality.
* **Efficiency and deployability**: With its 7B parameter budget and streamlined architecture, Ovis-Image fits on a single high-end GPU with moderate memory, supports low-latency interactive use, and scales to batch production serving, bringing near–frontier text rendering to applications where tens-of-billions–parameter models are impractical.

This pipeline was contributed by the Ovis-Image Team. The original codebase can be found [here](https://github.com/AIDC-AI/Ovis-Image).

Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`AIDC-AI/Ovis-Image-7B`](https://huggingface.co/AIDC-AI/Ovis-Image-7B) | `torch.bfloat16` |

Refer to [this](https://huggingface.co/collections/AIDC-AI/ovis-image) collection for more information.
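The snippet below is a minimal text-to-image sketch. It assumes the standard diffusers pipeline interface (`from_pretrained`, a text `prompt` argument, and an `images` field on the returned output); refer to the `OvisImagePipeline` reference below for the exact call signature and supported arguments.

```python
import torch

from diffusers import OvisImagePipeline

# Minimal sketch; assumes the usual diffusers text-to-image pipeline interface.
pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A minimalist poster with the words 'Ovis-Image' rendered in bold serif type"
image = pipe(prompt).images[0]
image.save("ovis_image_poster.png")
```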
41+
42+
## OvisImagePipeline
43+
44+
[[autodoc]] OvisImagePipeline
45+
- all
46+
- __call__
47+
48+
## OvisImagePipelineOutput
49+
50+
[[autodoc]] pipelines.ovis_image.pipeline_output.OvisImagePipelineOutput
scripts/convert_ovis_image_to_diffusers.py

Lines changed: 263 additions & 0 deletions

```python
import argparse
from contextlib import nullcontext

import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers import OvisImageTransformer2DModel
from diffusers.utils.import_utils import is_accelerate_available


"""
# Transformer

python scripts/convert_ovis_image_to_diffusers.py \
--original_state_dict_repo_id "AIDC-AI/Ovis-Image-7B" \
--filename "ovis_image.safetensors" \
--output_path "ovis-image" \
--transformer
"""


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="ovis_image.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


# In the SD3 original implementation of AdaLayerNormContinuous, the linear projection output is split into
# shift, scale; in diffusers it is split into scale, shift. Swap the linear projection weights so the
# diffusers implementation can be used.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_ovis_image_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["timestep_embedder.linear_1.weight"] = original_state_dict.pop("time_in.in_layer.weight")
    converted_state_dict["timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.in_layer.bias")
    converted_state_dict["timestep_embedder.linear_2.weight"] = original_state_dict.pop("time_in.out_layer.weight")
    converted_state_dict["timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.out_layer.bias")

    # context_embedder
    converted_state_dict["context_embedder_norm.weight"] = original_state_dict.pop("semantic_txt_norm.weight")
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("semantic_txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("semantic_txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.weight"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.weight"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.weight"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.weight"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = torch.cat(
            [
                original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.weight"),
                original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.weight"),
            ],
            dim=0,
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = torch.cat(
            [
                original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.bias"),
                original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.bias"),
            ],
            dim=0,
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.down_proj.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.down_proj.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = torch.cat(
            [
                original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.weight"),
                original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.weight"),
            ],
            dim=0,
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = torch.cat(
            [
                original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.bias"),
                original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.bias"),
            ],
            dim=0,
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.down_proj.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.down_proj.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim * 2)
        q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.weight"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.weight"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.weight"
        )
        converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.bias"
        )

    converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
    )

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)

    if args.transformer:
        num_layers = 6
        num_single_layers = 27
        inner_dim = 3072
        mlp_ratio = 4.0

        converted_transformer_state_dict = convert_ovis_image_transformer_checkpoint_to_diffusers(
            original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
        transformer = OvisImageTransformer2DModel(in_channels=args.in_channels, out_channels=args.out_channels)
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

        print("Saving Ovis-Image Transformer in Diffusers format.")
        transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")


if __name__ == "__main__":
    main(args)
```
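As a quick sanity check after running the conversion, the saved transformer can be loaded back with the regular diffusers `from_pretrained` API. The sketch below assumes the `--output_path "ovis-image"` value from the docstring example above:

```python
import torch

from diffusers import OvisImageTransformer2DModel

# Assumes the script was run with --output_path "ovis-image" (see the docstring example above).
transformer = OvisImageTransformer2DModel.from_pretrained(
    "ovis-image", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(f"Loaded converted transformer with {sum(p.numel() for p in transformer.parameters()):,} parameters.")
```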

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -242,6 +242,7 @@
     "MultiAdapter",
     "MultiControlNetModel",
     "OmniGenTransformer2DModel",
+    "OvisImageTransformer2DModel",
     "ParallelConfig",
     "PixArtTransformer2DModel",
     "PriorTransformer",
@@ -537,6 +538,7 @@
     "MochiPipeline",
     "MusicLDMPipeline",
     "OmniGenPipeline",
+    "OvisImagePipeline",
     "PaintByExamplePipeline",
     "PIAPipeline",
     "PixArtAlphaPipeline",
@@ -965,6 +967,7 @@
     MultiAdapter,
     MultiControlNetModel,
     OmniGenTransformer2DModel,
+    OvisImageTransformer2DModel,
     ParallelConfig,
     PixArtTransformer2DModel,
     PriorTransformer,
@@ -1230,6 +1233,7 @@
     MochiPipeline,
     MusicLDMPipeline,
     OmniGenPipeline,
+    OvisImagePipeline,
     PaintByExamplePipeline,
     PIAPipeline,
     PixArtAlphaPipeline,
```
