Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
98497006
Commit
98497006
authored
Apr 17, 2023
by
siutin
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
multi users support
parent
70ab21e6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
28 deletions
+59
-28
modules/call_queue.py
modules/call_queue.py
+14
-9
modules/progress.py
modules/progress.py
+43
-17
modules/ui.py
modules/ui.py
+2
-2
No files found.
modules/call_queue.py
View file @
98497006
...
@@ -4,6 +4,7 @@ import threading
...
@@ -4,6 +4,7 @@ import threading
import
traceback
import
traceback
import
time
import
time
import
gradio
as
gr
from
modules
import
shared
,
progress
from
modules
import
shared
,
progress
queue_lock
=
threading
.
Lock
()
queue_lock
=
threading
.
Lock
()
...
@@ -20,41 +21,45 @@ def wrap_queued_call(func):
...
@@ -20,41 +21,45 @@ def wrap_queued_call(func):
def
wrap_gradio_gpu_call
(
func
,
extra_outputs
=
None
):
def
wrap_gradio_gpu_call
(
func
,
extra_outputs
=
None
):
def
f
(
*
args
,
**
kwargs
):
def
f
(
request
:
gr
.
Request
,
*
args
,
**
kwargs
):
user
=
request
.
username
# if the first argument is a string that says "task(...)", it is treated as a job id
# if the first argument is a string that says "task(...)", it is treated as a job id
if
len
(
args
)
>
0
and
type
(
args
[
0
])
==
str
and
args
[
0
][
0
:
5
]
==
"task("
and
args
[
0
][
-
1
]
==
")"
:
if
len
(
args
)
>
0
and
type
(
args
[
0
])
==
str
and
args
[
0
][
0
:
5
]
==
"task("
and
args
[
0
][
-
1
]
==
")"
:
id_task
=
args
[
0
]
id_task
=
args
[
0
]
progress
.
add_task_to_queue
(
id_task
)
progress
.
add_task_to_queue
(
user
,
id_task
)
else
:
else
:
id_task
=
None
id_task
=
None
with
queue_lock
:
with
queue_lock
:
shared
.
state
.
begin
()
shared
.
state
.
begin
()
progress
.
start_task
(
id_task
)
progress
.
start_task
(
user
,
id_task
)
try
:
try
:
res
=
func
(
*
args
,
**
kwargs
)
res
=
func
(
*
args
,
**
kwargs
)
finally
:
finally
:
progress
.
finish_task
(
id_task
)
progress
.
finish_task
(
user
,
id_task
)
progress
.
set_last_task_result
(
id_task
,
res
)
progress
.
set_last_task_result
(
user
,
id_task
,
res
)
shared
.
state
.
end
()
shared
.
state
.
end
()
return
res
return
res
return
wrap_gradio_call
(
f
,
extra_outputs
=
extra_outputs
,
add_stats
=
True
)
return
wrap_gradio_call
(
f
,
extra_outputs
=
extra_outputs
,
add_stats
=
True
,
add_request
=
True
)
def
wrap_gradio_call
(
func
,
extra_outputs
=
None
,
add_stats
=
False
):
def
wrap_gradio_call
(
func
,
extra_outputs
=
None
,
add_stats
=
False
,
add_request
=
False
):
def
f
(
*
args
,
extra_outputs_array
=
extra_outputs
,
**
kwargs
):
def
f
(
request
:
gr
.
Request
,
*
args
,
extra_outputs_array
=
extra_outputs
,
**
kwargs
):
run_memmon
=
shared
.
opts
.
memmon_poll_rate
>
0
and
not
shared
.
mem_mon
.
disabled
and
add_stats
run_memmon
=
shared
.
opts
.
memmon_poll_rate
>
0
and
not
shared
.
mem_mon
.
disabled
and
add_stats
if
run_memmon
:
if
run_memmon
:
shared
.
mem_mon
.
monitor
()
shared
.
mem_mon
.
monitor
()
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
try
:
try
:
res
=
list
(
func
(
*
args
,
**
kwargs
))
if
add_request
:
res
=
list
(
func
(
request
,
*
args
,
**
kwargs
))
else
:
res
=
list
(
func
(
*
args
,
**
kwargs
))
except
Exception
as
e
:
except
Exception
as
e
:
# When printing out our debug argument list, do not print out more than a MB of text
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len
=
131072
# (1024*1024)/8
max_debug_str_len
=
131072
# (1024*1024)/8
...
...
modules/progress.py
View file @
98497006
...
@@ -4,7 +4,9 @@ import time
...
@@ -4,7 +4,9 @@ import time
import
gradio
as
gr
import
gradio
as
gr
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
from
typing
import
List
from
typing
import
Optional
from
fastapi
import
Depends
,
Security
from
fastapi.security
import
APIKeyCookie
from
modules
import
call_queue
from
modules
import
call_queue
from
modules.shared
import
opts
from
modules.shared
import
opts
...
@@ -12,57 +14,71 @@ from modules.shared import opts
...
@@ -12,57 +14,71 @@ from modules.shared import opts
import
modules.shared
as
shared
import
modules.shared
as
shared
current_task_user
=
None
current_task
=
None
current_task
=
None
pending_tasks
=
{}
pending_tasks
=
{}
finished_tasks
=
[]
finished_tasks
=
[]
def
start_task
(
id_task
):
def
start_task
(
user
,
id_task
):
global
current_task
global
current_task
global
current_task_user
current_task_user
=
user
current_task
=
id_task
current_task
=
id_task
pending_tasks
.
pop
(
id_task
,
None
)
pending_tasks
.
pop
(
(
user
,
id_task
)
,
None
)
def
finish_task
(
id_task
):
def
finish_task
(
user
,
id_task
):
global
current_task
global
current_task
global
current_task_user
if
current_task
==
id_task
:
if
current_task
==
id_task
:
current_task
=
None
current_task
=
None
finished_tasks
.
append
(
id_task
)
if
current_task_user
==
user
:
current_task_user
=
None
finished_tasks
.
append
((
user
,
id_task
))
if
len
(
finished_tasks
)
>
16
:
if
len
(
finished_tasks
)
>
16
:
finished_tasks
.
pop
(
0
)
finished_tasks
.
pop
(
0
)
def
add_task_to_queue
(
id_job
):
def
add_task_to_queue
(
user
,
id_job
):
pending_tasks
[
id_job
]
=
time
.
time
()
pending_tasks
[
(
user
,
id_job
)
]
=
time
.
time
()
last_task_id
=
None
last_task_id
=
None
last_task_result
=
None
last_task_result
=
None
last_task_user
=
None
def
set_last_task_result
(
user
,
id_job
,
result
):
def
set_last_task_result
(
id_job
,
result
):
global
last_task_id
global
last_task_id
global
last_task_result
global
last_task_result
global
last_task_user
last_task_id
=
id_job
last_task_id
=
id_job
last_task_result
=
result
last_task_result
=
result
last_task_user
=
user
def
restore_progress_call
():
def
restore_progress_call
(
request
:
gr
.
Request
):
if
current_task
is
None
:
if
current_task
is
None
:
# image, generation_info, html_info, html_log
# image, generation_info, html_info, html_log
return
tuple
(
list
([
None
,
None
,
None
,
None
]))
return
tuple
(
list
([
None
,
None
,
None
,
None
]))
else
:
else
:
user
=
request
.
username
t_task
=
current_task
if
current_task_user
==
user
:
with
call_queue
.
queue_lock_condition
:
t_task
=
current_task
call_queue
.
queue_lock_condition
.
wait_for
(
lambda
:
t_task
==
last_task_id
)
with
call_queue
.
queue_lock_condition
:
call_queue
.
queue_lock_condition
.
wait_for
(
lambda
:
t_task
==
last_task_id
)
return
last_task_result
return
last_task_result
return
tuple
(
list
([
None
,
None
,
None
,
None
]))
class
CurrentTaskResponse
(
BaseModel
):
class
CurrentTaskResponse
(
BaseModel
):
current_task
:
str
=
Field
(
default
=
None
,
title
=
"Task ID"
,
description
=
"id of the current progress task"
)
current_task
:
str
=
Field
(
default
=
None
,
title
=
"Task ID"
,
description
=
"id of the current progress task"
)
...
@@ -87,6 +103,19 @@ def setup_progress_api(app):
...
@@ -87,6 +103,19 @@ def setup_progress_api(app):
return
app
.
add_api_route
(
"/internal/progress"
,
progressapi
,
methods
=
[
"POST"
],
response_model
=
ProgressResponse
)
return
app
.
add_api_route
(
"/internal/progress"
,
progressapi
,
methods
=
[
"POST"
],
response_model
=
ProgressResponse
)
def
setup_current_task_api
(
app
):
def
setup_current_task_api
(
app
):
def
get_current_user
(
token
:
Optional
[
str
]
=
Security
(
APIKeyCookie
(
name
=
"access-token"
,
auto_error
=
False
))):
return
None
if
token
is
None
else
app
.
tokens
.
get
(
token
)
def
current_task_api
(
current_user
:
str
=
Depends
(
get_current_user
)):
if
app
.
auth
is
None
or
current_task_user
==
current_user
:
current_user_task
=
current_task
else
:
current_user_task
=
None
return
CurrentTaskResponse
(
current_task
=
current_user_task
)
return
app
.
add_api_route
(
"/internal/current_task"
,
current_task_api
,
methods
=
[
"GET"
],
response_model
=
CurrentTaskResponse
)
return
app
.
add_api_route
(
"/internal/current_task"
,
current_task_api
,
methods
=
[
"GET"
],
response_model
=
CurrentTaskResponse
)
def
progressapi
(
req
:
ProgressRequest
):
def
progressapi
(
req
:
ProgressRequest
):
...
@@ -127,7 +156,4 @@ def progressapi(req: ProgressRequest):
...
@@ -127,7 +156,4 @@ def progressapi(req: ProgressRequest):
else
:
else
:
live_preview
=
None
live_preview
=
None
return
ProgressResponse
(
active
=
active
,
queued
=
queued
,
completed
=
completed
,
progress
=
progress
,
eta
=
eta
,
live_preview
=
live_preview
,
id_live_preview
=
id_live_preview
,
textinfo
=
shared
.
state
.
textinfo
)
return
ProgressResponse
(
active
=
active
,
queued
=
queued
,
completed
=
completed
,
progress
=
progress
,
eta
=
eta
,
live_preview
=
live_preview
,
id_live_preview
=
id_live_preview
,
textinfo
=
shared
.
state
.
textinfo
)
\ No newline at end of file
def
current_task_api
():
return
CurrentTaskResponse
(
current_task
=
current_task
)
\ No newline at end of file
modules/ui.py
View file @
98497006
...
@@ -582,7 +582,7 @@ def create_ui():
...
@@ -582,7 +582,7 @@ def create_ui():
res_switch_btn
.
click
(
lambda
w
,
h
:
(
h
,
w
),
inputs
=
[
width
,
height
],
outputs
=
[
width
,
height
],
show_progress
=
False
)
res_switch_btn
.
click
(
lambda
w
,
h
:
(
h
,
w
),
inputs
=
[
width
,
height
],
outputs
=
[
width
,
height
],
show_progress
=
False
)
restore_progress_button
.
click
(
restore_progress_button
.
click
(
fn
=
lambda
:
restore_progress_call
()
,
fn
=
restore_progress_call
,
_js
=
"() => restoreProgress('txt2img')"
,
_js
=
"() => restoreProgress('txt2img')"
,
inputs
=
[],
inputs
=
[],
outputs
=
[
outputs
=
[
...
@@ -914,7 +914,7 @@ def create_ui():
...
@@ -914,7 +914,7 @@ def create_ui():
res_switch_btn
.
click
(
lambda
w
,
h
:
(
h
,
w
),
inputs
=
[
width
,
height
],
outputs
=
[
width
,
height
],
show_progress
=
False
)
res_switch_btn
.
click
(
lambda
w
,
h
:
(
h
,
w
),
inputs
=
[
width
,
height
],
outputs
=
[
width
,
height
],
show_progress
=
False
)
restore_progress_button
.
click
(
restore_progress_button
.
click
(
fn
=
lambda
:
restore_progress_call
()
,
fn
=
restore_progress_call
,
_js
=
"() => restoreProgress('img2img')"
,
_js
=
"() => restoreProgress('img2img')"
,
inputs
=
[],
inputs
=
[],
outputs
=
[
outputs
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment