Skip to content

Commit 46ce3fc

Browse files
committed
v1.1
1 parent c365e96 commit 46ce3fc

File tree

4 files changed

+786
-0
lines changed

4 files changed

+786
-0
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
import os, re, time
5+
import numpy as np
6+
from PIL import Image
7+
import imageio
8+
from tqdm import tqdm
9+
import torch
10+
from einops import rearrange
11+
12+
from diffsynth import ModelManager, FlashVSRFullPipeline
13+
from utils.utils import Causal_LQ4x_Proj
14+
15+
def tensor2video(frames: torch.Tensor):
16+
frames = rearrange(frames, "C T H W -> T H W C")
17+
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
18+
frames = [Image.fromarray(frame) for frame in frames]
19+
return frames
20+
21+
def natural_key(name: str):
22+
return [int(t) if t.isdigit() else t.lower() for t in re.split(r'([0-9]+)', os.path.basename(name))]
23+
24+
def list_images_natural(folder: str):
25+
exts = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
26+
fs = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(exts)]
27+
fs.sort(key=natural_key)
28+
return fs
29+
30+
def largest_8n1_leq(n): # 8n+1
31+
return 0 if n < 1 else ((n - 1)//8)*8 + 1
32+
33+
def is_video(path):
34+
return os.path.isfile(path) and path.lower().endswith(('.mp4','.mov','.avi','.mkv'))
35+
36+
def pil_to_tensor_neg1_1(img: Image.Image, dtype=torch.bfloat16, device='cuda'):
37+
t = torch.from_numpy(np.asarray(img, np.uint8)).to(device=device, dtype=torch.float32) # HWC
38+
t = t.permute(2,0,1) / 255.0 * 2.0 - 1.0 # CHW in [-1,1]
39+
return t.to(dtype)
40+
41+
def save_video(frames, save_path, fps=30, quality=5):
42+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
43+
w = imageio.get_writer(save_path, fps=fps, quality=quality)
44+
for f in tqdm(frames, desc=f"Saving {os.path.basename(save_path)}"):
45+
w.append_data(np.array(f))
46+
w.close()
47+
48+
def compute_scaled_and_target_dims(w0: int, h0: int, scale: int = 4, multiple: int = 128):
49+
if w0 <= 0 or h0 <= 0:
50+
raise ValueError("invalid original size")
51+
52+
sW, sH = w0 * scale, h0 * scale
53+
tW = max(multiple, (sW // multiple) * multiple)
54+
tH = max(multiple, (sH // multiple) * multiple)
55+
return sW, sH, tW, tH
56+
57+
def upscale_then_center_crop(img: Image.Image, scale: int, tW: int, tH: int) -> Image.Image:
58+
w0, h0 = img.size
59+
sW, sH = w0 * scale, h0 * scale
60+
# 先放大
61+
up = img.resize((sW, sH), Image.BICUBIC)
62+
# 中心裁剪
63+
l = max(0, (sW - tW) // 2); t = max(0, (sH - tH) // 2)
64+
return up.crop((l, t, l + tW, t + tH))
65+
66+
def prepare_input_tensor(path: str, scale: int = 4, dtype=torch.bfloat16, device='cuda'):
67+
if os.path.isdir(path):
68+
paths0 = list_images_natural(path)
69+
if not paths0:
70+
raise FileNotFoundError(f"No images in {path}")
71+
with Image.open(paths0[0]) as _img0:
72+
w0, h0 = _img0.size
73+
N0 = len(paths0)
74+
print(f"[{os.path.basename(path)}] Original Resolution: {w0}x{h0} | Original Frames: {N0}")
75+
76+
sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128)
77+
print(f"[{os.path.basename(path)}] Scaled Resolution (x{scale}): {sW}x{sH} -> Target (128-multiple): {tW}x{tH}")
78+
79+
paths = paths0 + [paths0[-1]] * 4
80+
F = largest_8n1_leq(len(paths))
81+
if F == 0:
82+
raise RuntimeError(f"Not enough frames after padding in {path}. Got {len(paths)}.")
83+
paths = paths[:F]
84+
print(f"[{os.path.basename(path)}] Target Frames (8n-3): {F-4}")
85+
86+
frames = []
87+
for p in paths:
88+
with Image.open(p).convert('RGB') as img:
89+
img_out = upscale_then_center_crop(img, scale=scale, tW=tW, tH=tH)
90+
frames.append(pil_to_tensor_neg1_1(img_out, dtype, device))
91+
vid = torch.stack(frames, 0).permute(1,0,2,3).unsqueeze(0)
92+
fps = 30
93+
return vid, tH, tW, F, fps
94+
95+
if is_video(path):
96+
rdr = imageio.get_reader(path)
97+
first = Image.fromarray(rdr.get_data(0)).convert('RGB')
98+
w0, h0 = first.size
99+
100+
meta = {}
101+
try:
102+
meta = rdr.get_meta_data()
103+
except Exception:
104+
pass
105+
fps_val = meta.get('fps', 30)
106+
fps = int(round(fps_val)) if isinstance(fps_val, (int, float)) else 30
107+
108+
def count_frames(r):
109+
try:
110+
nf = meta.get('nframes', None)
111+
if isinstance(nf, int) and nf > 0:
112+
return nf
113+
except Exception:
114+
pass
115+
try:
116+
return r.count_frames()
117+
except Exception:
118+
n = 0
119+
try:
120+
while True:
121+
r.get_data(n); n += 1
122+
except Exception:
123+
return n
124+
125+
total = count_frames(rdr)
126+
if total <= 0:
127+
rdr.close()
128+
raise RuntimeError(f"Cannot read frames from {path}")
129+
130+
print(f"[{os.path.basename(path)}] Original Resolution: {w0}x{h0} | Original Frames: {total} | FPS: {fps}")
131+
132+
sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128)
133+
print(f"[{os.path.basename(path)}] Scaled Resolution (x{scale}): {sW}x{sH} -> Target (128-multiple): {tW}x{tH}")
134+
135+
idx = list(range(total)) + [total - 1] * 4
136+
F = largest_8n1_leq(len(idx))
137+
if F == 0:
138+
rdr.close()
139+
raise RuntimeError(f"Not enough frames after padding in {path}. Got {len(idx)}.")
140+
idx = idx[:F]
141+
print(f"[{os.path.basename(path)}] Target Frames (8n-3): {F-4}")
142+
143+
frames = []
144+
try:
145+
for i in idx:
146+
img = Image.fromarray(rdr.get_data(i)).convert('RGB')
147+
img_out = upscale_then_center_crop(img, scale=scale, tW=tW, tH=tH)
148+
frames.append(pil_to_tensor_neg1_1(img_out, dtype, device))
149+
finally:
150+
try:
151+
rdr.close()
152+
except Exception:
153+
pass
154+
155+
vid = torch.stack(frames, 0).permute(1,0,2,3).unsqueeze(0) # 1 C F H W
156+
return vid, tH, tW, F, fps
157+
158+
raise ValueError(f"Unsupported input: {path}")
159+
160+
def init_pipeline():
161+
print(torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device()))
162+
mm = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
163+
mm.load_models([
164+
"./FlashVSR/diffusion_pytorch_model_streaming_dmd.safetensors",
165+
"./FlashVSR/Wan2.1_VAE.pth",
166+
])
167+
pipe = FlashVSRFullPipeline.from_model_manager(mm, device="cuda")
168+
pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to("cuda", dtype=torch.bfloat16)
169+
LQ_proj_in_path = "./FlashVSR/LQ_proj_in.ckpt"
170+
if os.path.exists(LQ_proj_in_path):
171+
pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(LQ_proj_in_path, map_location="cpu"), strict=True)
172+
173+
pipe.denoising_model().LQ_proj_in.to('cuda')
174+
pipe.vae.model.encoder = None
175+
pipe.vae.model.conv1 = None
176+
pipe.to('cuda'); pipe.enable_vram_management(num_persistent_param_in_dit=None)
177+
pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
178+
return pipe
179+
180+
def main():
181+
RESULT_ROOT = "./results"
182+
os.makedirs(RESULT_ROOT, exist_ok=True)
183+
inputs = [
184+
"./inputs/example0.mp4",
185+
"./inputs/example1.mp4",
186+
"./inputs/example2.mp4",
187+
"./inputs/example3.mp4",
188+
]
189+
seed, scale, dtype, device = 0, 4, torch.bfloat16, 'cuda'
190+
sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
191+
pipe = init_pipeline()
192+
193+
for p in inputs:
194+
torch.cuda.empty_cache(); torch.cuda.ipc_collect()
195+
name = os.path.basename(p.rstrip('/'))
196+
if name.startswith('.'):
197+
continue
198+
try:
199+
LQ, th, tw, F, fps = prepare_input_tensor(p, scale=scale, dtype=dtype, device=device)
200+
except Exception as e:
201+
print(f"[Error] {name}: {e}")
202+
continue
203+
204+
video = pipe(
205+
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
206+
tiled=False,# Disable tiling: faster inference but higher VRAM usage.
207+
# Set to True for lower memory consumption at the cost of speed.
208+
LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
209+
topk_ratio=sparse_ratio*768*1280/(th*tw),
210+
kv_ratio=3.0,
211+
local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
212+
color_fix = True,
213+
)
214+
video = tensor2video(video)
215+
save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Full_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=6)
216+
print("Done.")
217+
218+
if __name__ == "__main__":
219+
main()

0 commit comments

Comments
 (0)