Skip to content

Commit 34fe2c5

Browse files
committed
support inpaint
1 parent f25249c commit 34fe2c5

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

scripts/extension.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,14 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
161161
donor_worker = job.worker
162162
except KeyError:
163163
if job.batch_size > 0:
164-
logger.warning(f"Worker '{job.worker.uuid}' had nothing")
164+
logger.warning(f"Worker '{job.worker.uuid}' had no images")
165+
continue
166+
except TypeError as e:
167+
if job.worker.response is None:
168+
logger.error(f"worker '{job.worker.uuid}' had no response")
169+
else:
170+
logger.exception(e)
171+
continue
165172

166173
injected_to_iteration = 0
167174
images_per_iteration = Script.world.get_current_output_size()
@@ -179,6 +186,10 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
179186
else:
180187
injected_to_iteration += 1
181188

189+
if donor_worker is None:
190+
logger.critical("couldn't collect any responses, distributed will do nothing")
191+
return
192+
182193
# generate and inject grid
183194
if opts.return_grid:
184195
grid = processing.images.image_grid(processed.images, len(processed.images))

scripts/spartan/Worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import base64
1717
import queue
1818
from modules.shared import state as master_state
19+
from modules.api.api import encode_pil_to_base64
1920

2021

2122
class InvalidWorkerResponse(Exception):
@@ -322,6 +323,14 @@ def request(self, payload: dict, option_payload: dict, sync_options: bool):
322323
images.append(image)
323324
payload['init_images'] = images
324325

326+
# if an image mask is present
327+
image_mask = payload.get('image_mask', None)
328+
if image_mask is not None:
329+
image_b64 = encode_pil_to_base64(image_mask)
330+
image_b64 = str(image_b64, 'utf-8')
331+
payload['mask'] = image_b64
332+
del payload['image_mask']
333+
325334
# see if there is anything else wrong with serializing to payload
326335
try:
327336
json.dumps(payload)

0 commit comments

Comments
 (0)