Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
fddb4883
Commit
fddb4883
authored
Oct 26, 2022
by
evshiron
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
prototype progress api
parent
99d728b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
14 deletions
+88
-14
modules/api/api.py
modules/api/api.py
+75
-14
modules/shared.py
modules/shared.py
+13
-0
No files found.
modules/api/api.py
View file @
fddb4883
import
time
from
modules.api.models
import
StableDiffusionTxt2ImgProcessingAPI
,
StableDiffusionImg2ImgProcessingAPI
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.sd_samplers
import
all_samplers
from
modules.extras
import
run_pnginfo
import
modules.shared
as
shared
from
modules
import
devices
import
uvicorn
from
fastapi
import
Body
,
APIRouter
,
HTTPException
from
fastapi.responses
import
JSONResponse
...
...
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
parameters
:
Json
info
:
Json
class
ProgressResponse
(
BaseModel
):
progress
:
float
eta_relative
:
float
state
:
Json
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
# and time start needs to be set
# the function has been modified into two parts
def
before_gpu_call
():
devices
.
torch_gc
()
shared
.
state
.
sampling_step
=
0
shared
.
state
.
job_count
=
-
1
shared
.
state
.
job_no
=
0
shared
.
state
.
job_timestamp
=
shared
.
state
.
get_job_timestamp
()
shared
.
state
.
current_latent
=
None
shared
.
state
.
current_image
=
None
shared
.
state
.
current_image_sampling_step
=
0
shared
.
state
.
skipped
=
False
shared
.
state
.
interrupted
=
False
shared
.
state
.
textinfo
=
None
shared
.
state
.
time_start
=
time
.
time
()
def
after_gpu_call
():
shared
.
state
.
job
=
""
shared
.
state
.
job_count
=
0
devices
.
torch_gc
()
class
Api
:
def
__init__
(
self
,
app
,
queue_lock
):
...
...
@@ -33,6 +67,7 @@ class Api:
self
.
queue_lock
=
queue_lock
self
.
app
.
add_api_route
(
"/sdapi/v1/txt2img"
,
self
.
text2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/img2img"
,
self
.
img2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/progress"
,
self
.
progressapi
,
methods
=
[
"GET"
])
def
__base64_to_image
(
self
,
base64_string
):
# if has a comma, deal with prefix
...
...
@@ -44,12 +79,12 @@ class Api:
def
text2imgapi
(
self
,
txt2imgreq
:
StableDiffusionTxt2ImgProcessingAPI
):
sampler_index
=
sampler_to_index
(
txt2imgreq
.
sampler_index
)
if
sampler_index
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
populate
=
txt2imgreq
.
copy
(
update
=
{
# Override __init__ params
"sd_model"
:
shared
.
sd_model
,
"sd_model"
:
shared
.
sd_model
,
"sampler_index"
:
sampler_index
[
0
],
"do_not_save_samples"
:
True
,
"do_not_save_grid"
:
True
...
...
@@ -57,9 +92,11 @@ class Api:
)
p
=
StableDiffusionProcessingTxt2Img
(
**
vars
(
populate
))
# Override object param
before_gpu_call
()
with
self
.
queue_lock
:
processed
=
process_images
(
p
)
after_gpu_call
()
b64images
=
[]
for
i
in
processed
.
images
:
buffer
=
io
.
BytesIO
()
...
...
@@ -67,30 +104,30 @@ class Api:
b64images
.
append
(
base64
.
b64encode
(
buffer
.
getvalue
()))
return
TextToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
txt2imgreq
)),
info
=
processed
.
js
())
def
img2imgapi
(
self
,
img2imgreq
:
StableDiffusionImg2ImgProcessingAPI
):
sampler_index
=
sampler_to_index
(
img2imgreq
.
sampler_index
)
if
sampler_index
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
init_images
=
img2imgreq
.
init_images
if
init_images
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Init image not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Init image not found"
)
mask
=
img2imgreq
.
mask
if
mask
:
mask
=
self
.
__base64_to_image
(
mask
)
populate
=
img2imgreq
.
copy
(
update
=
{
# Override __init__ params
"sd_model"
:
shared
.
sd_model
,
"sd_model"
:
shared
.
sd_model
,
"sampler_index"
:
sampler_index
[
0
],
"do_not_save_samples"
:
True
,
"do_not_save_grid"
:
True
,
"do_not_save_grid"
:
True
,
"mask"
:
mask
}
)
...
...
@@ -103,9 +140,11 @@ class Api:
p
.
init_images
=
imgs
# Override object param
before_gpu_call
()
with
self
.
queue_lock
:
processed
=
process_images
(
p
)
after_gpu_call
()
b64images
=
[]
for
i
in
processed
.
images
:
buffer
=
io
.
BytesIO
()
...
...
@@ -118,6 +157,28 @@ class Api:
return
ImageToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
img2imgreq
)),
info
=
processed
.
js
())
def
progressapi
(
self
):
# copy from check_progress_call of ui.py
if
shared
.
state
.
job_count
==
0
:
return
ProgressResponse
(
progress
=
0
,
eta_relative
=
0
,
state
=
shared
.
state
.
js
())
# avoid dividing zero
progress
=
0.01
if
shared
.
state
.
job_count
>
0
:
progress
+=
shared
.
state
.
job_no
/
shared
.
state
.
job_count
if
shared
.
state
.
sampling_steps
>
0
:
progress
+=
1
/
shared
.
state
.
job_count
*
shared
.
state
.
sampling_step
/
shared
.
state
.
sampling_steps
time_since_start
=
time
.
time
()
-
shared
.
state
.
time_start
eta
=
(
time_since_start
/
progress
)
eta_relative
=
eta
-
time_since_start
progress
=
min
(
progress
,
1
)
return
ProgressResponse
(
progress
=
progress
,
eta_relative
=
eta_relative
,
state
=
shared
.
state
.
js
())
def
extrasapi
(
self
):
raise
NotImplementedError
...
...
modules/shared.py
View file @
fddb4883
...
...
@@ -146,6 +146,19 @@ class State:
def
get_job_timestamp
(
self
):
return
datetime
.
datetime
.
now
()
.
strftime
(
"
%
Y
%
m
%
d
%
H
%
M
%
S"
)
# shouldn't this return job_timestamp?
def
js
(
self
):
obj
=
{
"skipped"
:
self
.
skipped
,
"interrupted"
:
self
.
skipped
,
"job"
:
self
.
job
,
"job_count"
:
self
.
job_count
,
"job_no"
:
self
.
job_no
,
"sampling_step"
:
self
.
sampling_step
,
"sampling_steps"
:
self
.
sampling_steps
,
}
return
json
.
dumps
(
obj
)
state
=
State
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment