Skip to content

Commit 5b7d86d

Browse files
Merge pull request #14314 from gayshub/master
Add allow specify the task id and get the location of task in the queue of pending task
2 parents 93eae69 + 6d7e57b commit 5b7d86d

File tree

4 files changed

+40
-4
lines changed

4 files changed

+40
-4
lines changed

modules/api/api.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import piexif
3232
import piexif.helper
3333
from contextlib import closing
34-
34+
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
3535

3636
def script_name_to_index(name, scripts):
3737
try:
@@ -336,6 +336,9 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel
336336
return script_args
337337

338338
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
339+
task_id = create_task_id("text2img")
340+
if txt2imgreq.force_task_id is None:
341+
task_id = txt2imgreq.force_task_id
339342
script_runner = scripts.scripts_txt2img
340343
if not script_runner.scripts:
341344
script_runner.initialize_scripts(False)
@@ -362,6 +365,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
362365
send_images = args.pop('send_images', True)
363366
args.pop('save_images', None)
364367

368+
add_task_to_queue(task_id)
369+
365370
with self.queue_lock:
366371
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
367372
p.is_api = True
@@ -371,12 +376,14 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
371376

372377
try:
373378
shared.state.begin(job="scripts_txt2img")
379+
start_task(task_id)
374380
if selectable_scripts is not None:
375381
p.script_args = script_args
376382
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
377383
else:
378384
p.script_args = tuple(script_args) # Need to pass args as tuple here
379385
processed = process_images(p)
386+
finish_task(task_id)
380387
finally:
381388
shared.state.end()
382389
shared.total_tqdm.clear()
@@ -386,6 +393,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
386393
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
387394

388395
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
396+
task_id = create_task_id("img2img")
397+
if img2imgreq.force_task_id is None:
398+
task_id = img2imgreq.force_task_id
399+
389400
init_images = img2imgreq.init_images
390401
if init_images is None:
391402
raise HTTPException(status_code=404, detail="Init image not found")
@@ -422,6 +433,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
422433
send_images = args.pop('send_images', True)
423434
args.pop('save_images', None)
424435

436+
add_task_to_queue(task_id)
437+
425438
with self.queue_lock:
426439
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
427440
p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -432,12 +445,14 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
432445

433446
try:
434447
shared.state.begin(job="scripts_img2img")
448+
start_task(task_id)
435449
if selectable_scripts is not None:
436450
p.script_args = script_args
437451
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
438452
else:
439453
p.script_args = tuple(script_args) # Need to pass args as tuple here
440454
processed = process_images(p)
455+
finish_task(task_id)
441456
finally:
442457
shared.state.end()
443458
shared.total_tqdm.clear()
@@ -511,7 +526,7 @@ def progressapi(self, req: models.ProgressRequest = Depends()):
511526
if shared.state.current_image and not req.skip_current_image:
512527
current_image = encode_pil_to_base64(shared.state.current_image)
513528

514-
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
529+
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
515530

516531
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
517532
image_b64 = interrogatereq.image

modules/api/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def generate_model(self):
107107
{"key": "send_images", "type": bool, "default": True},
108108
{"key": "save_images", "type": bool, "default": False},
109109
{"key": "alwayson_scripts", "type": dict, "default": {}},
110+
{"key": "force_task_id", "type": str, "default": None},
110111
]
111112
).generate_model()
112113

@@ -124,6 +125,7 @@ def generate_model(self):
124125
{"key": "send_images", "type": bool, "default": True},
125126
{"key": "save_images", "type": bool, "default": False},
126127
{"key": "alwayson_scripts", "type": dict, "default": {}},
128+
{"key": "force_task_id", "type": str, "default": None},
127129
]
128130
).generate_model()
129131

modules/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
10601060
hr_sampler_name: str = None
10611061
hr_prompt: str = ''
10621062
hr_negative_prompt: str = ''
1063+
force_task_id: str = None
10631064

10641065
cached_hr_uc = [None, None]
10651066
cached_hr_c = [None, None]
@@ -1393,6 +1394,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
13931394
inpainting_mask_invert: int = 0
13941395
initial_noise_multiplier: float = None
13951396
latent_mask: Image = None
1397+
force_task_id: str = None
13961398

13971399
image_mask: Any = field(default=None, init=False)
13981400

modules/progress.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
from modules.shared import opts
99

1010
import modules.shared as shared
11-
11+
from collections import OrderedDict
12+
import string
13+
import random
14+
from typing import List
1215

1316
current_task = None
14-
pending_tasks = {}
17+
pending_tasks = OrderedDict()
1518
finished_tasks = []
1619
recorded_results = []
1720
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
3437
if len(finished_tasks) > 16:
3538
finished_tasks.pop(0)
3639

40+
def create_task_id(task_type):
41+
N = 7
42+
res = ''.join(random.choices(string.ascii_uppercase +
43+
string.digits, k=N))
44+
return f"task({task_type}-{res})"
3745

3846
def record_results(id_task, res):
3947
recorded_results.append((id_task, res))
@@ -44,6 +52,9 @@ def record_results(id_task, res):
4452
def add_task_to_queue(id_job):
4553
pending_tasks[id_job] = time.time()
4654

55+
class PendingTasksResponse(BaseModel):
56+
size: int = Field(title="Pending task size")
57+
tasks: List[str] = Field(title="Pending task ids")
4758

4859
class ProgressRequest(BaseModel):
4960
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,8 +74,14 @@ class ProgressResponse(BaseModel):
6374

6475

6576
def setup_progress_api(app):
77+
app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"])
6678
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
6779

80+
def get_pending_tasks():
81+
pending_tasks_ids = list(pending_tasks)
82+
pending_len = len(pending_tasks_ids)
83+
return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
84+
6885

6986
def progressapi(req: ProgressRequest):
7087
active = req.id_task == current_task

0 commit comments

Comments
 (0)