Skip to content

Commit 363e836

Browse files
committed
merge dev
2 parents 2251f61 + e8fcec0 commit 363e836

File tree

8 files changed

+730
-280
lines changed

8 files changed

+730
-280
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
*.pyc
3+
workers.json
4+
config.json

preload.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import os
2+
from pathlib import Path
3+
from inspect import getsourcefile
4+
from os.path import abspath
5+
16
def preload(parser):
27
parser.add_argument(
38
"--distributed-remotes",
@@ -23,3 +28,11 @@ def preload(parser):
2328
help="Enable debug information",
2429
action="store_true"
2530
)
31+
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent
32+
config_path = extension_path.joinpath('distributed-config.json')
33+
# add config file
34+
parser.add_argument(
35+
"--distributed-config",
36+
help="config file to load / save, default: $WEBUI_PATH/distributed-config.json",
37+
default=config_path
38+
)

scripts/extension.py

Lines changed: 92 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,22 @@
1414
import urllib3
1515
import copy
1616
from modules.images import save_image
17-
from modules.shared import cmd_opts
17+
from modules.shared import opts, cmd_opts
18+
from modules.shared import state as webui_state
1819
import time
1920
from scripts.spartan.World import World, WorldAlreadyInitialized
2021
from scripts.spartan.UI import UI
21-
from modules.shared import opts
2222
from scripts.spartan.shared import logger
2323
from scripts.spartan.control_net import pack_control_net
2424
from modules.processing import fix_seed, Processed
25+
import signal
26+
import sys
2527

28+
old_sigint_handler = signal.getsignal(signal.SIGINT)
29+
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
2630

27-
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
28-
# TODO see if the current api has some sort of UUID generation functionality.
2931

32+
# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
3033
# noinspection PyMissingOrEmptyDocstring
3134
class Script(scripts.Script):
3235
worker_threads: List[Thread] = []
@@ -39,6 +42,7 @@ class Script(scripts.Script):
3942
first_run = True
4043
master_start = None
4144
runs_since_init = 0
45+
name = "distributed"
4246

4347
if verify_remotes is False:
4448
logger.warning(f"You have chosen to forego the verification of worker TLS certificates")
@@ -47,8 +51,17 @@ class Script(scripts.Script):
4751
# build world
4852
world = World(initial_payload=None, verify_remotes=verify_remotes)
4953
# add workers to the world
50-
for worker in cmd_opts.distributed_remotes:
51-
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
54+
world.load_config()
55+
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
56+
logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n"
57+
"gui/external modification of {world.config_path} will be prioritized going forward")
58+
59+
for worker in cmd_opts.distributed_remotes:
60+
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False)
61+
world.save_config()
62+
# do an early check to see which workers are online
63+
logger.info("doing initial ping sweep to see which workers are reachable")
64+
world.ping_remotes(indiscriminate=True)
5265

5366
def title(self):
5467
return "Distribute"
@@ -59,22 +72,25 @@ def show(self, is_img2img):
5972

6073
def ui(self, is_img2img):
6174
extension_ui = UI(script=Script, world=Script.world)
62-
extension_ui.create_root()
75+
root, api_exposed = extension_ui.create_ui()
76+
77+
# return some components that should be exposed to the api
78+
return api_exposed
6379

6480
@staticmethod
6581
def add_to_gallery(processed, p):
6682
"""adds generated images to the image gallery after waiting for all workers to finish"""
83+
webui_state.textinfo = "Distributed - injecting images"
6784

68-
def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None):
69-
image_params: json = response["parameters"]
85+
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
86+
image_params: json = response['parameters']
7087
image_info_post: json = json.loads(response["info"]) # image info known after processing
7188
num_response_images = image_params["batch_size"] * image_params["n_iter"]
7289

7390
seed = None
7491
subseed = None
7592
negative_prompt = None
7693

77-
7894
try:
7995
if num_response_images > 1:
8096
seed = image_info_post['all_seeds'][info_index]
@@ -84,10 +100,10 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
84100
seed = image_info_post['seed']
85101
subseed = image_info_post['subseed']
86102
negative_prompt = image_info_post['negative_prompt']
87-
except Exception:
103+
except IndexError:
88104
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
89-
logger.debug(f"Image at index {i} for '{worker.uuid}' was missing some post-generation data")
90-
processed_inject_image(image=image, info_index=0, iteration=iteration)
105+
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
106+
processed_inject_image(image=image, info_index=0, response=response)
91107
return
92108

93109
processed.all_seeds.append(seed)
@@ -105,20 +121,18 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
105121
if p.n_iter > 1: # if splitting by batch count
106122
num_remote_images *= p.n_iter - 1
107123

108-
logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
124+
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, "
125+
f"info-index: {info_index}")
109126

110127
if Script.world.thin_client_mode:
111128
p.all_negative_prompts = processed.all_negative_prompts
112129

113-
info_text = processing.create_infotext(
114-
p=p,
115-
all_prompts=processed.all_prompts,
116-
all_seeds=processed.all_seeds,
117-
all_subseeds=processed.all_subseeds,
118-
# comments=[""], # unimplemented upstream :(
119-
position_in_batch=true_image_pos if not grid else 0,
120-
iteration=0
121-
)
130+
try:
131+
info_text = image_info_post['infotexts'][i]
132+
except IndexError:
133+
if not grid:
134+
logger.warning(f"image {true_image_pos + 1} was missing info-text")
135+
info_text = processed.infotexts[0]
122136
processed.infotexts.append(info_text)
123137

124138
# automatically save received image to local disk if desired
@@ -146,45 +160,36 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
146160

147161
# some worker which we know has a good response that we can use for generating the grid
148162
donor_worker = None
149-
spoofed_iteration = p.n_iter
150163
for job in Script.world.jobs:
151164
if job.batch_size < 1 or job.worker.master:
152165
continue
153166

154167
try:
155168
images: json = job.worker.response["images"]
156169
# if we for some reason get more than we asked for
157-
if job.batch_size < len(images):
158-
logger.debug(f"Requested {job.batch_size} images from '{job.worker.uuid}', got {len(images)}")
170+
if (job.batch_size * p.n_iter) < len(images):
171+
logger.debug(f"Requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}")
159172

160173
if donor_worker is None:
161174
donor_worker = job.worker
162175
except KeyError:
163176
if job.batch_size > 0:
164-
logger.warning(f"Worker '{job.worker.uuid}' had no images")
177+
logger.warning(f"Worker '{job.worker.label}' had no images")
165178
continue
166179
except TypeError as e:
167180
if job.worker.response is None:
168-
logger.error(f"worker '{job.worker.uuid}' had no response")
181+
logger.error(f"worker '{job.worker.label}' had no response")
169182
else:
170183
logger.exception(e)
171184
continue
172185

173-
injected_to_iteration = 0
174-
images_per_iteration = Script.world.get_current_output_size()
175186
# visibly add work from workers to the image gallery
176187
for i in range(0, len(images)):
177188
image_bytes = base64.b64decode(images[i])
178189
image = Image.open(io.BytesIO(image_bytes))
179190

180191
# inject image
181-
processed_inject_image(image=image, info_index=i, iteration=spoofed_iteration, response=job.worker.response)
182-
183-
if injected_to_iteration >= images_per_iteration - 1:
184-
spoofed_iteration += 1
185-
injected_to_iteration = 0
186-
else:
187-
injected_to_iteration += 1
192+
processed_inject_image(image=image, info_index=i, response=job.worker.response)
188193

189194
if donor_worker is None:
190195
logger.critical("couldn't collect any responses, distributed will do nothing")
@@ -197,7 +202,6 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
197202
image=grid,
198203
info_index=0,
199204
save_path_override=p.outpath_grids,
200-
iteration=spoofed_iteration,
201205
grid=True,
202206
response=donor_worker.response
203207
)
@@ -228,14 +232,12 @@ def initialize(initial_payload):
228232
# runs every time the generate button is hit
229233
def run(self, p, *args):
230234
current_thread().name = "distributed_main"
231-
232-
if cmd_opts.distributed_remotes is None:
233-
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
234-
235235
Script.initialize(initial_payload=p)
236236

237237
# strip scripts that aren't yet supported and warn user
238238
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
239+
# { "script_name": { "args": ["value1", "value2", ...] }
240+
incompat_list = []
239241
for script in p.scripts.scripts:
240242
if script.alwayson is not True:
241243
continue
@@ -256,16 +258,34 @@ def run(self, p, *args):
256258

257259
continue
258260
else:
261+
# other scripts to pack
262+
# args_script_pack = {}
263+
# args_script_pack[title] = {"args": []}
264+
# for arg in p.script_args[script.args_from:script.args_to]:
265+
# args_script_pack[title]["args"].append(arg)
266+
# packed_script_args.append(args_script_pack)
259267
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
260268
if Script.runs_since_init < 1:
261-
logger.warning(f"Distributed doesn't yet support '{title}'")
269+
incompat_list.append(title)
270+
271+
if Script.runs_since_init < 1 and len(incompat_list) >= 1:
272+
m = "Distributed doesn't yet support:"
273+
for i in range(0, len(incompat_list)):
274+
m += f" {incompat_list[i]}"
275+
if i < len(incompat_list) - 1:
276+
m += ","
277+
logger.warning(m)
262278

263279
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
264280
# see test/basic_features/txt2img_test.py for an example
265281
payload = copy.copy(p.__dict__)
266282
payload['batch_size'] = Script.world.default_batch_size()
267283
payload['scripts'] = None
268-
del payload['script_args']
284+
try:
285+
del payload['script_args']
286+
except KeyError:
287+
del payload['script_args_value']
288+
269289

270290
payload['alwayson_scripts'] = {}
271291
for packed in packed_script_args:
@@ -279,7 +299,7 @@ def run(self, p, *args):
279299
# TODO api for some reason returns 200 even if something failed to be set.
280300
# for now we may have to make redundant GET requests to check if actually successful...
281301
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
282-
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
302+
name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"])
283303
vae = opts.data["sd_vae"]
284304
option_payload = {
285305
"sd_model_checkpoint": name,
@@ -297,7 +317,9 @@ def run(self, p, *args):
297317
return
298318

299319
for job in Script.world.jobs:
300-
payload_temp = copy.deepcopy(payload)
320+
payload_temp = copy.copy(payload)
321+
del payload_temp['scripts_value']
322+
payload_temp = copy.deepcopy(payload_temp)
301323

302324
if job.worker.master:
303325
started_jobs.append(job)
@@ -311,14 +333,17 @@ def run(self, p, *args):
311333
payload_temp['batch_size'] = job.batch_size
312334
payload_temp['subseed'] += prior_images
313335
payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0
314-
logger.debug(f"'{job.worker.uuid}' job's given starting seed is {payload_temp['seed']} with {prior_images} coming before it")
336+
logger.debug(
337+
f"'{job.worker.label}' job's given starting seed is "
338+
f"{payload_temp['seed']} with {prior_images} coming before it")
315339

316340
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
317341
sync = True
318342
job.worker.loaded_model = name
319343
job.worker.loaded_vae = vae
320344

321-
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync, ), name=f"{job.worker.uuid}_request")
345+
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
346+
name=f"{job.worker.label}_request")
322347

323348
t.start()
324349
Script.worker_threads.append(t)
@@ -341,8 +366,26 @@ def run(self, p, *args):
341366
processed.infotexts = []
342367
processed.prompt = None
343368
else:
344-
processed = processing.process_images(p, *args)
369+
processed = processing.process_images(p)
345370

346371
Script.add_to_gallery(processed, p)
347372
Script.runs_since_init += 1
348373
return processed
374+
375+
@staticmethod
376+
def signal_handler(sig, frame):
377+
logger.debug("handling interrupt signal")
378+
# do cleanup
379+
Script.world.save_config()
380+
381+
if sig == signal.SIGINT:
382+
if callable(old_sigint_handler):
383+
old_sigint_handler(sig, frame)
384+
else:
385+
if callable(old_sigterm_handler):
386+
old_sigterm_handler(sig, frame)
387+
else:
388+
sys.exit(0)
389+
390+
signal.signal(signal.SIGINT, signal_handler)
391+
signal.signal(signal.SIGTERM, signal_handler)

0 commit comments

Comments
 (0)