1414import urllib3
1515import copy
1616from modules .images import save_image
17- from modules .shared import cmd_opts
17+ from modules .shared import opts , cmd_opts
18+ from modules .shared import state as webui_state
1819import time
1920from scripts .spartan .World import World , WorldAlreadyInitialized
2021from scripts .spartan .UI import UI
21- from modules .shared import opts
2222from scripts .spartan .shared import logger
2323from scripts .spartan .control_net import pack_control_net
2424from modules .processing import fix_seed , Processed
25+ import signal
26+ import sys
2527
28+ old_sigint_handler = signal .getsignal (signal .SIGINT )
29+ old_sigterm_handler = signal .getsignal (signal .SIGTERM )
2630
27- # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
28- # TODO see if the current api has some sort of UUID generation functionality.
2931
32+ # TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
3033# noinspection PyMissingOrEmptyDocstring
3134class Script (scripts .Script ):
3235 worker_threads : List [Thread ] = []
@@ -39,6 +42,7 @@ class Script(scripts.Script):
3942 first_run = True
4043 master_start = None
4144 runs_since_init = 0
45+ name = "distributed"
4246
4347 if verify_remotes is False :
4448 logger .warning (f"You have chosen to forego the verification of worker TLS certificates" )
@@ -47,8 +51,17 @@ class Script(scripts.Script):
4751 # build world
4852 world = World (initial_payload = None , verify_remotes = verify_remotes )
4953 # add workers to the world
50- for worker in cmd_opts .distributed_remotes :
51- world .add_worker (uuid = worker [0 ], address = worker [1 ], port = worker [2 ])
54+ world .load_config ()
55+ if cmd_opts .distributed_remotes is not None and len (cmd_opts .distributed_remotes ) > 0 :
56+ logger .warning (f"--distributed-remotes is deprecated and may be removed in the future\n "
57+ "gui/external modification of {world.config_path} will be prioritized going forward" )
58+
59+ for worker in cmd_opts .distributed_remotes :
60+ world .add_worker (uuid = worker [0 ], address = worker [1 ], port = worker [2 ], tls = False )
61+ world .save_config ()
62+ # do an early check to see which workers are online
63+ logger .info ("doing initial ping sweep to see which workers are reachable" )
64+ world .ping_remotes (indiscriminate = True )
5265
5366 def title (self ):
5467 return "Distribute"
@@ -59,22 +72,25 @@ def show(self, is_img2img):
5972
6073 def ui (self , is_img2img ):
6174 extension_ui = UI (script = Script , world = Script .world )
62- extension_ui .create_root ()
75+ root , api_exposed = extension_ui .create_ui ()
76+
77+ # return some components that should be exposed to the api
78+ return api_exposed
6379
6480 @staticmethod
6581 def add_to_gallery (processed , p ):
6682 """adds generated images to the image gallery after waiting for all workers to finish"""
83+ webui_state .textinfo = "Distributed - injecting images"
6784
68- def processed_inject_image (image , info_index , iteration : int , save_path_override = None , grid = False , response = None ):
69- image_params : json = response [" parameters" ]
85+ def processed_inject_image (image , info_index , save_path_override = None , grid = False , response = None ):
86+ image_params : json = response [' parameters' ]
7087 image_info_post : json = json .loads (response ["info" ]) # image info known after processing
7188 num_response_images = image_params ["batch_size" ] * image_params ["n_iter" ]
7289
7390 seed = None
7491 subseed = None
7592 negative_prompt = None
7693
77-
7894 try :
7995 if num_response_images > 1 :
8096 seed = image_info_post ['all_seeds' ][info_index ]
@@ -84,10 +100,10 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
84100 seed = image_info_post ['seed' ]
85101 subseed = image_info_post ['subseed' ]
86102 negative_prompt = image_info_post ['negative_prompt' ]
87- except Exception :
103+ except IndexError :
88104 # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
89- logger .debug (f"Image at index { i } for '{ worker .uuid } ' was missing some post-generation data" )
90- processed_inject_image (image = image , info_index = 0 , iteration = iteration )
105+ logger .debug (f"Image at index { i } for '{ job . worker .label } ' was missing some post-generation data" )
106+ processed_inject_image (image = image , info_index = 0 , response = response )
91107 return
92108
93109 processed .all_seeds .append (seed )
@@ -105,20 +121,18 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
105121 if p .n_iter > 1 : # if splitting by batch count
106122 num_remote_images *= p .n_iter - 1
107123
108- logger .debug (f"iteration { iteration } /{ p .n_iter } , image { true_image_pos + 1 } /{ Script .world .total_batch_size * p .n_iter } , info-index: { info_index } " )
124+ logger .debug (f"image { true_image_pos + 1 } /{ Script .world .total_batch_size * p .n_iter } , "
125+ f"info-index: { info_index } " )
109126
110127 if Script .world .thin_client_mode :
111128 p .all_negative_prompts = processed .all_negative_prompts
112129
113- info_text = processing .create_infotext (
114- p = p ,
115- all_prompts = processed .all_prompts ,
116- all_seeds = processed .all_seeds ,
117- all_subseeds = processed .all_subseeds ,
118- # comments=[""], # unimplemented upstream :(
119- position_in_batch = true_image_pos if not grid else 0 ,
120- iteration = 0
121- )
130+ try :
131+ info_text = image_info_post ['infotexts' ][i ]
132+ except IndexError :
133+ if not grid :
134+ logger .warning (f"image { true_image_pos + 1 } was missing info-text" )
135+ info_text = processed .infotexts [0 ]
122136 processed .infotexts .append (info_text )
123137
124138 # automatically save received image to local disk if desired
@@ -146,45 +160,36 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
146160
147161 # some worker which we know has a good response that we can use for generating the grid
148162 donor_worker = None
149- spoofed_iteration = p .n_iter
150163 for job in Script .world .jobs :
151164 if job .batch_size < 1 or job .worker .master :
152165 continue
153166
154167 try :
155168 images : json = job .worker .response ["images" ]
156169 # if we for some reason get more than we asked for
157- if job .batch_size < len (images ):
158- logger .debug (f"Requested { job .batch_size } images from '{ job .worker .uuid } ', got { len (images )} " )
170+ if ( job .batch_size * p . n_iter ) < len (images ):
171+ logger .debug (f"Requested { job .batch_size } image(s) from '{ job .worker .label } ', got { len (images )} " )
159172
160173 if donor_worker is None :
161174 donor_worker = job .worker
162175 except KeyError :
163176 if job .batch_size > 0 :
164- logger .warning (f"Worker '{ job .worker .uuid } ' had no images" )
177+ logger .warning (f"Worker '{ job .worker .label } ' had no images" )
165178 continue
166179 except TypeError as e :
167180 if job .worker .response is None :
168- logger .error (f"worker '{ job .worker .uuid } ' had no response" )
181+ logger .error (f"worker '{ job .worker .label } ' had no response" )
169182 else :
170183 logger .exception (e )
171184 continue
172185
173- injected_to_iteration = 0
174- images_per_iteration = Script .world .get_current_output_size ()
175186 # visibly add work from workers to the image gallery
176187 for i in range (0 , len (images )):
177188 image_bytes = base64 .b64decode (images [i ])
178189 image = Image .open (io .BytesIO (image_bytes ))
179190
180191 # inject image
181- processed_inject_image (image = image , info_index = i , iteration = spoofed_iteration , response = job .worker .response )
182-
183- if injected_to_iteration >= images_per_iteration - 1 :
184- spoofed_iteration += 1
185- injected_to_iteration = 0
186- else :
187- injected_to_iteration += 1
192+ processed_inject_image (image = image , info_index = i , response = job .worker .response )
188193
189194 if donor_worker is None :
190195 logger .critical ("couldn't collect any responses, distributed will do nothing" )
@@ -197,7 +202,6 @@ def processed_inject_image(image, info_index, iteration: int, save_path_override
197202 image = grid ,
198203 info_index = 0 ,
199204 save_path_override = p .outpath_grids ,
200- iteration = spoofed_iteration ,
201205 grid = True ,
202206 response = donor_worker .response
203207 )
@@ -228,14 +232,12 @@ def initialize(initial_payload):
228232 # runs every time the generate button is hit
229233 def run (self , p , * args ):
230234 current_thread ().name = "distributed_main"
231-
232- if cmd_opts .distributed_remotes is None :
233- raise RuntimeError ("Distributed - No remotes passed. (Try using `--distributed-remotes`?)" )
234-
235235 Script .initialize (initial_payload = p )
236236
237237 # strip scripts that aren't yet supported and warn user
238238 packed_script_args : List [dict ] = [] # list of api formatted per-script argument objects
239+ # { "script_name": { "args": ["value1", "value2", ...] }
240+ incompat_list = []
239241 for script in p .scripts .scripts :
240242 if script .alwayson is not True :
241243 continue
@@ -256,16 +258,34 @@ def run(self, p, *args):
256258
257259 continue
258260 else :
261+ # other scripts to pack
262+ # args_script_pack = {}
263+ # args_script_pack[title] = {"args": []}
264+ # for arg in p.script_args[script.args_from:script.args_to]:
265+ # args_script_pack[title]["args"].append(arg)
266+ # packed_script_args.append(args_script_pack)
259267 # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
260268 if Script .runs_since_init < 1 :
261- logger .warning (f"Distributed doesn't yet support '{ title } '" )
269+ incompat_list .append (title )
270+
271+ if Script .runs_since_init < 1 and len (incompat_list ) >= 1 :
272+ m = "Distributed doesn't yet support:"
273+ for i in range (0 , len (incompat_list )):
274+ m += f" { incompat_list [i ]} "
275+ if i < len (incompat_list ) - 1 :
276+ m += ","
277+ logger .warning (m )
262278
263279 # encapsulating the request object within a txt2imgreq object is deprecated and no longer works
264280 # see test/basic_features/txt2img_test.py for an example
265281 payload = copy .copy (p .__dict__ )
266282 payload ['batch_size' ] = Script .world .default_batch_size ()
267283 payload ['scripts' ] = None
268- del payload ['script_args' ]
284+ try :
285+ del payload ['script_args' ]
286+ except KeyError :
287+ del payload ['script_args_value' ]
288+
269289
270290 payload ['alwayson_scripts' ] = {}
271291 for packed in packed_script_args :
@@ -279,7 +299,7 @@ def run(self, p, *args):
279299 # TODO api for some reason returns 200 even if something failed to be set.
280300 # for now we may have to make redundant GET requests to check if actually successful...
281301 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
282- name = re .sub (r'\s?\[[^\ ]]*\ ]$' , '' , opts .data ["sd_model_checkpoint" ])
302+ name = re .sub (r'\s?\[[^]]*]$' , '' , opts .data ["sd_model_checkpoint" ])
283303 vae = opts .data ["sd_vae" ]
284304 option_payload = {
285305 "sd_model_checkpoint" : name ,
@@ -297,7 +317,9 @@ def run(self, p, *args):
297317 return
298318
299319 for job in Script .world .jobs :
300- payload_temp = copy .deepcopy (payload )
320+ payload_temp = copy .copy (payload )
321+ del payload_temp ['scripts_value' ]
322+ payload_temp = copy .deepcopy (payload_temp )
301323
302324 if job .worker .master :
303325 started_jobs .append (job )
@@ -311,14 +333,17 @@ def run(self, p, *args):
311333 payload_temp ['batch_size' ] = job .batch_size
312334 payload_temp ['subseed' ] += prior_images
313335 payload_temp ['seed' ] += prior_images if payload_temp ['subseed_strength' ] == 0 else 0
314- logger .debug (f"'{ job .worker .uuid } ' job's given starting seed is { payload_temp ['seed' ]} with { prior_images } coming before it" )
336+ logger .debug (
337+ f"'{ job .worker .label } ' job's given starting seed is "
338+ f"{ payload_temp ['seed' ]} with { prior_images } coming before it" )
315339
316340 if job .worker .loaded_model != name or job .worker .loaded_vae != vae :
317341 sync = True
318342 job .worker .loaded_model = name
319343 job .worker .loaded_vae = vae
320344
321- t = Thread (target = job .worker .request , args = (payload_temp , option_payload , sync , ), name = f"{ job .worker .uuid } _request" )
345+ t = Thread (target = job .worker .request , args = (payload_temp , option_payload , sync ,),
346+ name = f"{ job .worker .label } _request" )
322347
323348 t .start ()
324349 Script .worker_threads .append (t )
@@ -341,8 +366,26 @@ def run(self, p, *args):
341366 processed .infotexts = []
342367 processed .prompt = None
343368 else :
344- processed = processing .process_images (p , * args )
369+ processed = processing .process_images (p )
345370
346371 Script .add_to_gallery (processed , p )
347372 Script .runs_since_init += 1
348373 return processed
374+
375+ @staticmethod
376+ def signal_handler (sig , frame ):
377+ logger .debug ("handling interrupt signal" )
378+ # do cleanup
379+ Script .world .save_config ()
380+
381+ if sig == signal .SIGINT :
382+ if callable (old_sigint_handler ):
383+ old_sigint_handler (sig , frame )
384+ else :
385+ if callable (old_sigterm_handler ):
386+ old_sigterm_handler (sig , frame )
387+ else :
388+ sys .exit (0 )
389+
390+ signal .signal (signal .SIGINT , signal_handler )
391+ signal .signal (signal .SIGTERM , signal_handler )
0 commit comments