diff --git a/.gitignore b/.gitignore
index 79174c8..ab24571 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,3 +172,6 @@ local.ipynb
 convert_to_safetensor.py
 test_memory_time.py
 test.py
+
+# cog
+.cog/
\ No newline at end of file
diff --git a/predict.py b/predict.py
index 8ef6129..78e9c9f 100644
--- a/predict.py
+++ b/predict.py
@@ -5,6 +5,7 @@
 import subprocess
 import time
 import sys
+from typing import List
 
 from cog import BasePredictor, Input, Path
 from PIL import Image
@@ -96,31 +97,34 @@ def predict(
             description="Automatically adjust the output image size to be same as input image size. For editing and controlnet task, it can make sure the output image has the same size as input image leading to better performance",
             default=False,
         ),
-    ) -> Path:
+        num_images_per_prompt: int = Input(description="The number of images to generate for the given inputs", default=1, ge=1, le=10),
+    ) -> List[Path]:
         """Run a single prediction on the model"""
         if seed is None:
             seed = int.from_bytes(os.urandom(2), "big")
         print(f"Using seed: {seed}")
 
         input_images = [str(img) for img in [img1, img2, img3] if img is not None]
-
-        output = self.pipe(
-            prompt=prompt,
-            input_images=None if len(input_images) == 0 else input_images,
-            height=height,
-            width=width,
-            guidance_scale=guidance_scale,
-            img_guidance_scale=img_guidance_scale,
-            num_inference_steps=inference_steps,
-            separate_cfg_infer=separate_cfg_infer,
-            use_kv_cache=True,
-            offload_kv_cache=True,
-            offload_model=offload_model,
-            use_input_image_size_as_output=use_input_image_size_as_output,
-            seed=seed,
-            max_input_image_size=max_input_image_size,
-        )
-        img = output[0]
-        out_path = "/tmp/out.png"
-        img.save(out_path)
-        return Path(out_path)
+        outputs = []  # collected output paths, one per generated image
+        for i in range(num_images_per_prompt):
+            output = self.pipe(
+                prompt=prompt,
+                input_images=None if len(input_images) == 0 else input_images,
+                height=height,
+                width=width,
+                guidance_scale=guidance_scale,
+                img_guidance_scale=img_guidance_scale,
+                num_inference_steps=inference_steps,
+                separate_cfg_infer=separate_cfg_infer,
+                use_kv_cache=True,
+                offload_kv_cache=True,
+                offload_model=offload_model,
+                use_input_image_size_as_output=use_input_image_size_as_output,
+                seed=seed + i,  # offset the seed per image so repeated generations differ instead of producing identical outputs
+                max_input_image_size=max_input_image_size,
+            )
+            img = output[0]
+            out_path = f"/tmp/out_{i}.png"
+            img.save(out_path)
+            outputs.append(Path(out_path))
+        return outputs
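
Usage sketch (not part of the diff; assumes the Cog CLI is installed, the model weights are available, prompt is the only required input, and the prompt value is illustrative):

    cog predict -i prompt="a photo of a corgi wearing a party hat" -i num_images_per_prompt=3

With num_images_per_prompt=3 the predictor returns a list of three output paths (/tmp/out_0.png, /tmp/out_1.png, /tmp/out_2.png) instead of a single Path.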