Commit 98497006 authored by siutin's avatar siutin

multi users support

parent 70ab21e6
......@@ -4,6 +4,7 @@ import threading
import traceback
import time
import gradio as gr
from modules import shared, progress
queue_lock = threading.Lock()
......@@ -20,41 +21,45 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
def f(request: gr.Request, *args, **kwargs):
user = request.username
# if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0]
progress.add_task_to_queue(id_task)
progress.add_task_to_queue(user, id_task)
else:
id_task = None
with queue_lock:
shared.state.begin()
progress.start_task(id_task)
progress.start_task(user, id_task)
try:
res = func(*args, **kwargs)
finally:
progress.finish_task(id_task)
progress.set_last_task_result(id_task, res)
progress.finish_task(user, id_task)
progress.set_last_task_result(user, id_task, res)
shared.state.end()
return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False):
def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
if add_request:
res = list(func(request, *args, **kwargs))
else:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
......
......@@ -4,7 +4,9 @@ import time
import gradio as gr
from pydantic import BaseModel, Field
from typing import List
from typing import Optional
from fastapi import Depends, Security
from fastapi.security import APIKeyCookie
from modules import call_queue
from modules.shared import opts
......@@ -12,57 +14,71 @@ from modules.shared import opts
import modules.shared as shared
current_task_user = None
current_task = None
pending_tasks = {}
finished_tasks = []
def start_task(id_task):
def start_task(user, id_task):
global current_task
global current_task_user
current_task_user = user
current_task = id_task
pending_tasks.pop(id_task, None)
pending_tasks.pop((user, id_task), None)
def finish_task(id_task):
def finish_task(user, id_task):
global current_task
global current_task_user
if current_task == id_task:
current_task = None
finished_tasks.append(id_task)
if current_task_user == user:
current_task_user = None
finished_tasks.append((user, id_task))
if len(finished_tasks) > 16:
finished_tasks.pop(0)
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
def add_task_to_queue(user, id_job):
pending_tasks[(user, id_job)] = time.time()
last_task_id = None
last_task_result = None
last_task_user = None
def set_last_task_result(user, id_job, result):
def set_last_task_result(id_job, result):
global last_task_id
global last_task_result
global last_task_user
last_task_id = id_job
last_task_result = result
last_task_user = user
def restore_progress_call():
def restore_progress_call(request: gr.Request):
if current_task is None:
# image, generation_info, html_info, html_log
return tuple(list([None, None, None, None]))
else:
user = request.username
t_task = current_task
with call_queue.queue_lock_condition:
call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)
if current_task_user == user:
t_task = current_task
with call_queue.queue_lock_condition:
call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)
return last_task_result
return last_task_result
return tuple(list([None, None, None, None]))
class CurrentTaskResponse(BaseModel):
current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
......@@ -87,6 +103,19 @@ def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def setup_current_task_api(app):
def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))):
return None if token is None else app.tokens.get(token)
def current_task_api(current_user: str = Depends(get_current_user)):
if app.auth is None or current_task_user == current_user:
current_user_task = current_task
else:
current_user_task = None
return CurrentTaskResponse(current_task=current_user_task)
return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
def progressapi(req: ProgressRequest):
......@@ -127,7 +156,4 @@ def progressapi(req: ProgressRequest):
else:
live_preview = None
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
def current_task_api():
return CurrentTaskResponse(current_task=current_task)
\ No newline at end of file
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
\ No newline at end of file
......@@ -582,7 +582,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=lambda: restore_progress_call(),
fn=restore_progress_call,
_js="() => restoreProgress('txt2img')",
inputs=[],
outputs=[
......@@ -914,7 +914,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=lambda: restore_progress_call(),
fn=restore_progress_call,
_js="() => restoreProgress('img2img')",
inputs=[],
outputs=[
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment