Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
ea8aa170
Commit
ea8aa170
authored
Oct 15, 2022
by
AUTOMATIC1111
Committed by
GitHub
Oct 15, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into master
parents
4d19f3b7
a13af34b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
219 additions
and
139 deletions
+219
-139
javascript/images_history.js
javascript/images_history.js
+12
-4
javascript/progressbar.js
javascript/progressbar.js
+35
-0
javascript/ui.js
javascript/ui.js
+0
-10
launch.py
launch.py
+2
-1
modules/hypernetworks/hypernetwork.py
modules/hypernetworks/hypernetwork.py
+24
-9
modules/images_history.py
modules/images_history.py
+100
-84
modules/processing.py
modules/processing.py
+1
-1
modules/textual_inversion/dataset.py
modules/textual_inversion/dataset.py
+20
-13
modules/textual_inversion/textual_inversion.py
modules/textual_inversion/textual_inversion.py
+8
-9
modules/ui.py
modules/ui.py
+15
-6
script.js
script.js
+2
-2
No files found.
javascript/images_history.js
View file @
ea8aa170
...
@@ -163,10 +163,15 @@ function images_history_init(){
...
@@ -163,10 +163,15 @@ function images_history_init(){
for
(
var
i
in
images_history_tab_list
){
for
(
var
i
in
images_history_tab_list
){
var
tabname
=
images_history_tab_list
[
i
]
var
tabname
=
images_history_tab_list
[
i
]
tab_btns
[
i
].
setAttribute
(
"
tabname
"
,
tabname
);
tab_btns
[
i
].
setAttribute
(
"
tabname
"
,
tabname
);
tab_btns
[
i
].
addEventListener
(
'
click
'
,
images_history_click_tab
);
// this refreshes history upon tab switch
// until the history is known to work well, which is not the case now, we do not do this at startup
//tab_btns[i].addEventListener('click', images_history_click_tab);
}
}
tabs_box
.
classList
.
add
(
images_history_tab_list
[
0
]);
tabs_box
.
classList
.
add
(
images_history_tab_list
[
0
]);
load_txt2img_button
.
click
();
// same as above, at page load
//load_txt2img_button.click();
}
else
{
}
else
{
setTimeout
(
images_history_init
,
500
);
setTimeout
(
images_history_init
,
500
);
}
}
...
@@ -182,12 +187,15 @@ document.addEventListener("DOMContentLoaded", function() {
...
@@ -182,12 +187,15 @@ document.addEventListener("DOMContentLoaded", function() {
buttons
.
forEach
(
function
(
bnt
){
buttons
.
forEach
(
function
(
bnt
){
bnt
.
addEventListener
(
'
click
'
,
images_history_click_image
,
true
);
bnt
.
addEventListener
(
'
click
'
,
images_history_click_image
,
true
);
});
});
// same as load_txt2img_button.click() above
/*
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
if (cls_btn){
if (cls_btn){
cls_btn.addEventListener('click', function(){
cls_btn.addEventListener('click', function(){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, false);
}, false);
}
}
*/
}
}
});
});
...
...
javascript/progressbar.js
View file @
ea8aa170
// code related to showing and updating progressbar shown as the image is being made
// code related to showing and updating progressbar shown as the image is being made
global_progressbars
=
{}
global_progressbars
=
{}
galleries
=
{}
galleryObservers
=
{}
function
check_progressbar
(
id_part
,
id_progressbar
,
id_progressbar_span
,
id_skip
,
id_interrupt
,
id_preview
,
id_gallery
){
function
check_progressbar
(
id_part
,
id_progressbar
,
id_progressbar_span
,
id_skip
,
id_interrupt
,
id_preview
,
id_gallery
){
var
progressbar
=
gradioApp
().
getElementById
(
id_progressbar
)
var
progressbar
=
gradioApp
().
getElementById
(
id_progressbar
)
...
@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
...
@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
preview
.
style
.
width
=
gallery
.
clientWidth
+
"
px
"
preview
.
style
.
width
=
gallery
.
clientWidth
+
"
px
"
preview
.
style
.
height
=
gallery
.
clientHeight
+
"
px
"
preview
.
style
.
height
=
gallery
.
clientHeight
+
"
px
"
//only watch gallery if there is a generation process going on
check_gallery
(
id_gallery
);
var
progressDiv
=
gradioApp
().
querySelectorAll
(
'
#
'
+
id_progressbar_span
).
length
>
0
;
var
progressDiv
=
gradioApp
().
querySelectorAll
(
'
#
'
+
id_progressbar_span
).
length
>
0
;
if
(
!
progressDiv
){
if
(
!
progressDiv
){
if
(
skip
)
{
if
(
skip
)
{
skip
.
style
.
display
=
"
none
"
skip
.
style
.
display
=
"
none
"
}
}
interrupt
.
style
.
display
=
"
none
"
interrupt
.
style
.
display
=
"
none
"
//disconnect observer once generation finished, so user can close selected image if they want
if
(
galleryObservers
[
id_gallery
])
{
galleryObservers
[
id_gallery
].
disconnect
();
galleries
[
id_gallery
]
=
null
;
}
}
}
}
}
window
.
setTimeout
(
function
()
{
requestMoreProgress
(
id_part
,
id_progressbar_span
,
id_skip
,
id_interrupt
)
},
500
)
window
.
setTimeout
(
function
()
{
requestMoreProgress
(
id_part
,
id_progressbar_span
,
id_skip
,
id_interrupt
)
},
500
)
...
@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
...
@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
}
}
}
}
function
check_gallery
(
id_gallery
){
let
gallery
=
gradioApp
().
getElementById
(
id_gallery
)
// if gallery has no change, no need to setting up observer again.
if
(
gallery
&&
galleries
[
id_gallery
]
!==
gallery
){
galleries
[
id_gallery
]
=
gallery
;
if
(
galleryObservers
[
id_gallery
]){
galleryObservers
[
id_gallery
].
disconnect
();
}
let
prevSelectedIndex
=
selected_gallery_index
();
galleryObservers
[
id_gallery
]
=
new
MutationObserver
(
function
(){
let
galleryButtons
=
gradioApp
().
querySelectorAll
(
'
#
'
+
id_gallery
+
'
.gallery-item
'
)
let
galleryBtnSelected
=
gradioApp
().
querySelector
(
'
#
'
+
id_gallery
+
'
.gallery-item.
\\
!ring-2
'
)
if
(
prevSelectedIndex
!==
-
1
&&
galleryButtons
.
length
>
prevSelectedIndex
&&
!
galleryBtnSelected
)
{
//automatically re-open previously selected index (if exists)
galleryButtons
[
prevSelectedIndex
].
click
();
showGalleryImage
();
}
})
galleryObservers
[
id_gallery
].
observe
(
gallery
,
{
childList
:
true
,
subtree
:
false
})
}
}
onUiUpdate
(
function
(){
onUiUpdate
(
function
(){
check_progressbar
(
'
txt2img
'
,
'
txt2img_progressbar
'
,
'
txt2img_progress_span
'
,
'
txt2img_skip
'
,
'
txt2img_interrupt
'
,
'
txt2img_preview
'
,
'
txt2img_gallery
'
)
check_progressbar
(
'
txt2img
'
,
'
txt2img_progressbar
'
,
'
txt2img_progress_span
'
,
'
txt2img_skip
'
,
'
txt2img_interrupt
'
,
'
txt2img_preview
'
,
'
txt2img_gallery
'
)
check_progressbar
(
'
img2img
'
,
'
img2img_progressbar
'
,
'
img2img_progress_span
'
,
'
img2img_skip
'
,
'
img2img_interrupt
'
,
'
img2img_preview
'
,
'
img2img_gallery
'
)
check_progressbar
(
'
img2img
'
,
'
img2img_progressbar
'
,
'
img2img_progress_span
'
,
'
img2img_skip
'
,
'
img2img_interrupt
'
,
'
img2img_preview
'
,
'
img2img_gallery
'
)
...
...
javascript/ui.js
View file @
ea8aa170
...
@@ -187,12 +187,10 @@ onUiUpdate(function(){
...
@@ -187,12 +187,10 @@ onUiUpdate(function(){
if
(
!
txt2img_textarea
)
{
if
(
!
txt2img_textarea
)
{
txt2img_textarea
=
gradioApp
().
querySelector
(
"
#txt2img_prompt > label > textarea
"
);
txt2img_textarea
=
gradioApp
().
querySelector
(
"
#txt2img_prompt > label > textarea
"
);
txt2img_textarea
?.
addEventListener
(
"
input
"
,
()
=>
update_token_counter
(
"
txt2img_token_button
"
));
txt2img_textarea
?.
addEventListener
(
"
input
"
,
()
=>
update_token_counter
(
"
txt2img_token_button
"
));
txt2img_textarea
?.
addEventListener
(
"
keyup
"
,
(
event
)
=>
submit_prompt
(
event
,
"
txt2img_generate
"
));
}
}
if
(
!
img2img_textarea
)
{
if
(
!
img2img_textarea
)
{
img2img_textarea
=
gradioApp
().
querySelector
(
"
#img2img_prompt > label > textarea
"
);
img2img_textarea
=
gradioApp
().
querySelector
(
"
#img2img_prompt > label > textarea
"
);
img2img_textarea
?.
addEventListener
(
"
input
"
,
()
=>
update_token_counter
(
"
img2img_token_button
"
));
img2img_textarea
?.
addEventListener
(
"
input
"
,
()
=>
update_token_counter
(
"
img2img_token_button
"
));
img2img_textarea
?.
addEventListener
(
"
keyup
"
,
(
event
)
=>
submit_prompt
(
event
,
"
img2img_generate
"
));
}
}
})
})
...
@@ -220,14 +218,6 @@ function update_token_counter(button_id) {
...
@@ -220,14 +218,6 @@ function update_token_counter(button_id) {
token_timeout
=
setTimeout
(()
=>
gradioApp
().
getElementById
(
button_id
)?.
click
(),
wait_time
);
token_timeout
=
setTimeout
(()
=>
gradioApp
().
getElementById
(
button_id
)?.
click
(),
wait_time
);
}
}
function
submit_prompt
(
event
,
generate_button_id
)
{
if
(
event
.
altKey
&&
event
.
keyCode
===
13
)
{
event
.
preventDefault
();
gradioApp
().
getElementById
(
generate_button_id
).
click
();
return
;
}
}
function
restart_reload
(){
function
restart_reload
(){
document
.
body
.
innerHTML
=
'
<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>
'
;
document
.
body
.
innerHTML
=
'
<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>
'
;
setTimeout
(
function
(){
location
.
reload
()},
2000
)
setTimeout
(
function
(){
location
.
reload
()},
2000
)
...
...
launch.py
View file @
ea8aa170
...
@@ -9,6 +9,7 @@ import platform
...
@@ -9,6 +9,7 @@ import platform
dir_repos
=
"repositories"
dir_repos
=
"repositories"
python
=
sys
.
executable
python
=
sys
.
executable
git
=
os
.
environ
.
get
(
'GIT'
,
"git"
)
git
=
os
.
environ
.
get
(
'GIT'
,
"git"
)
index_url
=
os
.
environ
.
get
(
'INDEX_URL'
,
""
)
def
extract_arg
(
args
,
name
):
def
extract_arg
(
args
,
name
):
...
@@ -57,7 +58,7 @@ def run_python(code, desc=None, errdesc=None):
...
@@ -57,7 +58,7 @@ def run_python(code, desc=None, errdesc=None):
def
run_pip
(
args
,
desc
=
None
):
def
run_pip
(
args
,
desc
=
None
):
return
run
(
f
'"{python}" -m pip {args} --prefer-binary'
,
desc
=
f
"Installing {desc}"
,
errdesc
=
f
"Couldn't install {desc}"
)
return
run
(
f
'"{python}" -m pip {args} --prefer-binary
{f'
--
index
-
url
{
index_url
}
' if index_url!='' else ''}
'
,
desc
=
f
"Installing {desc}"
,
errdesc
=
f
"Couldn't install {desc}"
)
def
check_run_python
(
code
):
def
check_run_python
(
code
):
...
...
modules/hypernetworks/hypernetwork.py
View file @
ea8aa170
...
@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
...
@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
data_root
,
log_directory
,
steps
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
def
stack_conds
(
conds
):
if
len
(
conds
)
==
1
:
return
torch
.
stack
(
conds
)
# same as in reconstruct_multicond_batch
token_count
=
max
([
x
.
shape
[
0
]
for
x
in
conds
])
for
i
in
range
(
len
(
conds
)):
if
conds
[
i
]
.
shape
[
0
]
!=
token_count
:
last_vector
=
conds
[
i
][
-
1
:]
last_vector_repeated
=
last_vector
.
repeat
([
token_count
-
conds
[
i
]
.
shape
[
0
],
1
])
conds
[
i
]
=
torch
.
vstack
([
conds
[
i
],
last_vector_repeated
])
return
torch
.
stack
(
conds
)
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
batch_size
,
data_root
,
log_directory
,
steps
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
assert
hypernetwork_name
,
'hypernetwork not selected'
assert
hypernetwork_name
,
'hypernetwork not selected'
path
=
shared
.
hypernetworks
.
get
(
hypernetwork_name
,
None
)
path
=
shared
.
hypernetworks
.
get
(
hypernetwork_name
,
None
)
...
@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
...
@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared
.
state
.
textinfo
=
f
"Preparing dataset from {html.escape(data_root)}..."
shared
.
state
.
textinfo
=
f
"Preparing dataset from {html.escape(data_root)}..."
with
torch
.
autocast
(
"cuda"
):
with
torch
.
autocast
(
"cuda"
):
ds
=
modules
.
textual_inversion
.
dataset
.
PersonalizedBase
(
data_root
=
data_root
,
width
=
512
,
height
=
512
,
repeats
=
shared
.
opts
.
training_image_repeats_per_epoch
,
placeholder_token
=
hypernetwork_name
,
model
=
shared
.
sd_model
,
device
=
devices
.
device
,
template_file
=
template_file
,
include_cond
=
True
)
ds
=
modules
.
textual_inversion
.
dataset
.
PersonalizedBase
(
data_root
=
data_root
,
width
=
512
,
height
=
512
,
repeats
=
shared
.
opts
.
training_image_repeats_per_epoch
,
placeholder_token
=
hypernetwork_name
,
model
=
shared
.
sd_model
,
device
=
devices
.
device
,
template_file
=
template_file
,
include_cond
=
True
,
batch_size
=
batch_size
)
if
unload
:
if
unload
:
shared
.
sd_model
.
cond_stage_model
.
to
(
devices
.
cpu
)
shared
.
sd_model
.
cond_stage_model
.
to
(
devices
.
cpu
)
...
@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
...
@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer
=
torch
.
optim
.
AdamW
(
weights
,
lr
=
scheduler
.
learn_rate
)
optimizer
=
torch
.
optim
.
AdamW
(
weights
,
lr
=
scheduler
.
learn_rate
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
for
i
,
entr
y
in
pbar
:
for
i
,
entr
ies
in
pbar
:
hypernetwork
.
step
=
i
+
ititial_step
hypernetwork
.
step
=
i
+
ititial_step
scheduler
.
apply
(
optimizer
,
hypernetwork
.
step
)
scheduler
.
apply
(
optimizer
,
hypernetwork
.
step
)
...
@@ -246,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
...
@@ -246,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break
break
with
torch
.
autocast
(
"cuda"
):
with
torch
.
autocast
(
"cuda"
):
cond
=
entry
.
cond
.
to
(
devices
.
device
)
c
=
stack_conds
([
entry
.
cond
for
entry
in
entries
])
.
to
(
devices
.
device
)
x
=
entry
.
latent
.
to
(
devices
.
device
)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
loss
=
shared
.
sd_model
(
x
.
unsqueeze
(
0
),
cond
)[
0
]
x
=
torch
.
stack
([
entry
.
latent
for
entry
in
entries
])
.
to
(
devices
.
device
)
loss
=
shared
.
sd_model
(
x
,
c
)[
0
]
del
x
del
x
del
c
ond
del
c
losses
[
hypernetwork
.
step
%
losses
.
shape
[
0
]]
=
loss
.
item
()
losses
[
hypernetwork
.
step
%
losses
.
shape
[
0
]]
=
loss
.
item
()
...
@@ -292,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
...
@@ -292,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
p
.
width
=
preview_width
p
.
width
=
preview_width
p
.
height
=
preview_height
p
.
height
=
preview_height
else
:
else
:
p
.
prompt
=
entr
y
.
cond_text
p
.
prompt
=
entr
ies
[
0
]
.
cond_text
p
.
steps
=
20
p
.
steps
=
20
preview_text
=
p
.
prompt
preview_text
=
p
.
prompt
...
@@ -315,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
...
@@ -315,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
<p>
Loss: {losses.mean():.7f}<br/>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entr
y
.cond_text)}<br/>
Last prompt: {html.escape(entr
ies[0]
.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
</p>
...
...
modules/images_history.py
View file @
ea8aa170
import
os
import
os
import
shutil
import
shutil
def
traverse_all_files
(
output_dir
,
image_list
,
curr_dir
=
None
):
def
traverse_all_files
(
output_dir
,
image_list
,
curr_dir
=
None
):
curr_path
=
output_dir
if
curr_dir
is
None
else
os
.
path
.
join
(
output_dir
,
curr_dir
)
curr_path
=
output_dir
if
curr_dir
is
None
else
os
.
path
.
join
(
output_dir
,
curr_dir
)
try
:
try
:
...
@@ -16,10 +18,10 @@ def traverse_all_files(output_dir, image_list, curr_dir=None):
...
@@ -16,10 +18,10 @@ def traverse_all_files(output_dir, image_list, curr_dir=None):
elif
os
.
path
.
isfile
(
file_path
)
and
file
[
-
10
:]
.
rfind
(
"."
)
>
0
:
elif
os
.
path
.
isfile
(
file_path
)
and
file
[
-
10
:]
.
rfind
(
"."
)
>
0
:
image_list
.
append
(
file
)
image_list
.
append
(
file
)
else
:
else
:
image_list
=
traverse_all_files
(
output_dir
,
image_list
,
file
)
image_list
=
traverse_all_files
(
output_dir
,
image_list
,
file
)
return
image_list
return
image_list
def
get_recent_images
(
dir_name
,
page_index
,
step
,
image_index
,
tabname
):
def
get_recent_images
(
dir_name
,
page_index
,
step
,
image_index
,
tabname
):
page_index
=
int
(
page_index
)
page_index
=
int
(
page_index
)
f_list
=
os
.
listdir
(
dir_name
)
f_list
=
os
.
listdir
(
dir_name
)
...
@@ -27,36 +29,48 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname):
...
@@ -27,36 +29,48 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname):
image_list
=
traverse_all_files
(
dir_name
,
image_list
)
image_list
=
traverse_all_files
(
dir_name
,
image_list
)
image_list
=
sorted
(
image_list
,
key
=
lambda
file
:
-
os
.
path
.
getctime
(
os
.
path
.
join
(
dir_name
,
file
)))
image_list
=
sorted
(
image_list
,
key
=
lambda
file
:
-
os
.
path
.
getctime
(
os
.
path
.
join
(
dir_name
,
file
)))
num
=
48
if
tabname
!=
"extras"
else
12
num
=
48
if
tabname
!=
"extras"
else
12
max_page_index
=
len
(
image_list
)
//
num
+
1
max_page_index
=
len
(
image_list
)
//
num
+
1
page_index
=
max_page_index
if
page_index
==
-
1
else
page_index
+
step
page_index
=
max_page_index
if
page_index
==
-
1
else
page_index
+
step
page_index
=
1
if
page_index
<
1
else
page_index
page_index
=
1
if
page_index
<
1
else
page_index
page_index
=
max_page_index
if
page_index
>
max_page_index
else
page_index
page_index
=
max_page_index
if
page_index
>
max_page_index
else
page_index
idx_frm
=
(
page_index
-
1
)
*
num
idx_frm
=
(
page_index
-
1
)
*
num
image_list
=
image_list
[
idx_frm
:
idx_frm
+
num
]
image_list
=
image_list
[
idx_frm
:
idx_frm
+
num
]
image_index
=
int
(
image_index
)
image_index
=
int
(
image_index
)
if
image_index
<
0
or
image_index
>
len
(
image_list
)
-
1
:
if
image_index
<
0
or
image_index
>
len
(
image_list
)
-
1
:
current_file
=
None
current_file
=
None
hidden
=
None
hidden
=
None
else
:
else
:
current_file
=
image_list
[
int
(
image_index
)]
current_file
=
image_list
[
int
(
image_index
)]
hidden
=
os
.
path
.
join
(
dir_name
,
current_file
)
hidden
=
os
.
path
.
join
(
dir_name
,
current_file
)
return
[
os
.
path
.
join
(
dir_name
,
file
)
for
file
in
image_list
],
page_index
,
image_list
,
current_file
,
hidden
,
""
return
[
os
.
path
.
join
(
dir_name
,
file
)
for
file
in
image_list
],
page_index
,
image_list
,
current_file
,
hidden
,
""
def
first_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
def
first_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
return
get_recent_images
(
dir_name
,
1
,
0
,
image_index
,
tabname
)
return
get_recent_images
(
dir_name
,
1
,
0
,
image_index
,
tabname
)
def
end_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
def
end_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
return
get_recent_images
(
dir_name
,
-
1
,
0
,
image_index
,
tabname
)
return
get_recent_images
(
dir_name
,
-
1
,
0
,
image_index
,
tabname
)
def
prev_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
def
prev_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
return
get_recent_images
(
dir_name
,
page_index
,
-
1
,
image_index
,
tabname
)
return
get_recent_images
(
dir_name
,
page_index
,
-
1
,
image_index
,
tabname
)
def
next_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
def
next_page_click
(
dir_name
,
page_index
,
image_index
,
tabname
):
return
get_recent_images
(
dir_name
,
page_index
,
1
,
image_index
,
tabname
)
return
get_recent_images
(
dir_name
,
page_index
,
1
,
image_index
,
tabname
)
def
page_index_change
(
dir_name
,
page_index
,
image_index
,
tabname
):
def
page_index_change
(
dir_name
,
page_index
,
image_index
,
tabname
):
return
get_recent_images
(
dir_name
,
page_index
,
0
,
image_index
,
tabname
)
return
get_recent_images
(
dir_name
,
page_index
,
0
,
image_index
,
tabname
)
def
show_image_info
(
num
,
image_path
,
filenames
):
def
show_image_info
(
num
,
image_path
,
filenames
):
#print(f"select image {num}")
#
print(f"select image {num}")
file
=
filenames
[
int
(
num
)]
file
=
filenames
[
int
(
num
)]
return
file
,
num
,
os
.
path
.
join
(
image_path
,
file
)
return
file
,
num
,
os
.
path
.
join
(
image_path
,
file
)
def
delete_image
(
delete_num
,
tabname
,
dir_name
,
name
,
page_index
,
filenames
,
image_index
):
def
delete_image
(
delete_num
,
tabname
,
dir_name
,
name
,
page_index
,
filenames
,
image_index
):
if
name
==
""
:
if
name
==
""
:
return
filenames
,
delete_num
return
filenames
,
delete_num
...
@@ -66,14 +80,14 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
...
@@ -66,14 +80,14 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
i
=
0
i
=
0
new_file_list
=
[]
new_file_list
=
[]
for
name
in
filenames
:
for
name
in
filenames
:
if
i
>=
index
and
i
<
index
+
delete_num
:
if
i
>=
index
and
i
<
index
+
delete_num
:
path
=
os
.
path
.
join
(
dir_name
,
name
)
path
=
os
.
path
.
join
(
dir_name
,
name
)
if
os
.
path
.
exists
(
path
):
if
os
.
path
.
exists
(
path
):
print
(
f
"Delete file {path}"
)
print
(
f
"Delete file {path}"
)
os
.
remove
(
path
)
os
.
remove
(
path
)
txt_file
=
os
.
path
.
splitext
(
path
)[
0
]
+
".txt"
txt_file
=
os
.
path
.
splitext
(
path
)[
0
]
+
".txt"
if
os
.
path
.
exists
(
txt_file
):
if
os
.
path
.
exists
(
txt_file
):
os
.
remove
(
txt_file
)
os
.
remove
(
txt_file
)
else
:
else
:
print
(
f
"Not exists file {path}"
)
print
(
f
"Not exists file {path}"
)
else
:
else
:
...
@@ -81,81 +95,83 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
...
@@ -81,81 +95,83 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
i
+=
1
i
+=
1
return
new_file_list
,
1
return
new_file_list
,
1
def
show_images_history
(
gr
,
opts
,
tabname
,
run_pnginfo
,
switch_dict
):
def
show_images_history
(
gr
,
opts
,
tabname
,
run_pnginfo
,
switch_dict
):
if
tabname
==
"txt2img"
:
if
opts
.
outdir_samples
!=
""
:
dir_name
=
opts
.
outdir_txt2img_samples
dir_name
=
opts
.
outdir_samples
elif
tabname
==
"img2img"
:
elif
tabname
==
"txt2img"
:
dir_name
=
opts
.
outdir_img2img_samples
dir_name
=
opts
.
outdir_txt2img_samples
elif
tabname
==
"extras"
:
elif
tabname
==
"img2img"
:
dir_name
=
opts
.
outdir_extras_samples
dir_name
=
opts
.
outdir_img2img_samples
d
=
dir_name
.
split
(
"/"
)
elif
tabname
==
"extras"
:
dir_name
=
d
[
0
]
dir_name
=
opts
.
outdir_extras_samples
for
p
in
d
[
1
:]:
d
=
dir_name
.
split
(
"/"
)
dir_name
=
os
.
path
.
join
(
dir_name
,
p
)
dir_name
=
"/"
if
dir_name
.
startswith
(
"/"
)
else
d
[
0
]
with
gr
.
Row
():
for
p
in
d
[
1
:]:
renew_page
=
gr
.
Button
(
'Renew Page'
,
elem_id
=
tabname
+
"_images_history_renew_page"
)
dir_name
=
os
.
path
.
join
(
dir_name
,
p
)
first_page
=
gr
.
Button
(
'First Page'
)
with
gr
.
Row
():
prev_page
=
gr
.
Button
(
'Prev Page'
)
renew_page
=
gr
.
Button
(
'Renew Page'
,
elem_id
=
tabname
+
"_images_history_renew_page"
)
page_index
=
gr
.
Number
(
value
=
1
,
label
=
"Page Index"
)
first_page
=
gr
.
Button
(
'First Page'
)
next_page
=
gr
.
Button
(
'Next Page'
)
prev_page
=
gr
.
Button
(
'Prev Page'
)
end_page
=
gr
.
Button
(
'End Page'
)
page_index
=
gr
.
Number
(
value
=
1
,
label
=
"Page Index"
)
with
gr
.
Row
(
elem_id
=
tabname
+
"_images_history"
):
next_page
=
gr
.
Button
(
'Next Page'
)
with
gr
.
Row
():
end_page
=
gr
.
Button
(
'End Page'
)
with
gr
.
Column
(
scale
=
2
):
with
gr
.
Row
(
elem_id
=
tabname
+
"_images_history"
):
history_gallery
=
gr
.
Gallery
(
show_label
=
False
,
elem_id
=
tabname
+
"_images_history_gallery"
)
.
style
(
grid
=
6
)
with
gr
.
Row
():
with
gr
.
Row
():
with
gr
.
Column
(
scale
=
2
):
delete_num
=
gr
.
Number
(
value
=
1
,
interactive
=
True
,
label
=
"number of images to delete consecutively next"
)
history_gallery
=
gr
.
Gallery
(
show_label
=
False
,
elem_id
=
tabname
+
"_images_history_gallery"
)
.
style
(
grid
=
6
)
delete
=
gr
.
Button
(
'Delete'
,
elem_id
=
tabname
+
"_images_history_del_button"
)
with
gr
.
Row
():
with
gr
.
Column
():
delete_num
=
gr
.
Number
(
value
=
1
,
interactive
=
True
,
label
=
"number of images to delete consecutively next"
)
with
gr
.
Row
():
delete
=
gr
.
Button
(
'Delete'
,
elem_id
=
tabname
+
"_images_history_del_button"
)
pnginfo_send_to_txt2img
=
gr
.
Button
(
'Send to txt2img'
)
with
gr
.
Column
():
pnginfo_send_to_img2img
=
gr
.
Button
(
'Send to img2img'
)
with
gr
.
Row
():
with
gr
.
Row
():
pnginfo_send_to_txt2img
=
gr
.
Button
(
'Send to txt2img'
)
with
gr
.
Column
():
pnginfo_send_to_img2img
=
gr
.
Button
(
'Send to img2img'
)
img_file_info
=
gr
.
Textbox
(
label
=
"Generate Info"
,
interactive
=
False
)
with
gr
.
Row
():
img_file_name
=
gr
.
Textbox
(
label
=
"File Name"
,
interactive
=
False
)
with
gr
.
Column
():
with
gr
.
Row
():
img_file_info
=
gr
.
Textbox
(
label
=
"Generate Info"
,
interactive
=
False
)
# hiden items
img_file_name
=
gr
.
Textbox
(
label
=
"File Name"
,
interactive
=
False
)
with
gr
.
Row
():
img_path
=
gr
.
Textbox
(
dir_name
.
rstrip
(
"/"
)
,
visible
=
False
)
# hiden items
tabname_box
=
gr
.
Textbox
(
tabname
,
visible
=
False
)
image_index
=
gr
.
Textbox
(
value
=-
1
,
visible
=
False
)
img_path
=
gr
.
Textbox
(
dir_name
.
rstrip
(
"/"
),
visible
=
False
)
set_index
=
gr
.
Button
(
'set_index'
,
elem_id
=
tabname
+
"_images_history_set_index"
,
visible
=
False
)
tabname_box
=
gr
.
Textbox
(
tabname
,
visible
=
False
)
filenames
=
gr
.
State
()
image_index
=
gr
.
Textbox
(
value
=-
1
,
visible
=
False
)
hidden
=
gr
.
Image
(
type
=
"pil"
,
visible
=
False
)
set_index
=
gr
.
Button
(
'set_index'
,
elem_id
=
tabname
+
"_images_history_set_index"
,
visible
=
False
)
info1
=
gr
.
Textbox
(
visible
=
False
)
filenames
=
gr
.
State
()
info2
=
gr
.
Textbox
(
visible
=
False
)
hidden
=
gr
.
Image
(
type
=
"pil"
,
visible
=
False
)
info1
=
gr
.
Textbox
(
visible
=
False
)
info2
=
gr
.
Textbox
(
visible
=
False
)
# turn pages
gallery_inputs
=
[
img_path
,
page_index
,
image_index
,
tabname_box
]
# turn pages
gallery_outputs
=
[
history_gallery
,
page_index
,
filenames
,
img_file_name
,
hidden
,
img_file_name
]
gallery_inputs
=
[
img_path
,
page_index
,
image_index
,
tabname_box
]
gallery_outputs
=
[
history_gallery
,
page_index
,
filenames
,
img_file_name
,
hidden
,
img_file_name
]
first_page
.
click
(
first_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
next_page
.
click
(
next_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
first_page
.
click
(
first_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
prev_page
.
click
(
prev_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
next_page
.
click
(
next_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
end_page
.
click
(
end_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
prev_page
.
click
(
prev_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
page_index
.
submit
(
page_index_change
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
end_page
.
click
(
end_page_click
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
renew_page
.
click
(
page_index_change
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
page_index
.
submit
(
page_index_change
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
#page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
renew_page
.
click
(
page_index_change
,
_js
=
"images_history_turnpage"
,
inputs
=
gallery_inputs
,
outputs
=
gallery_outputs
)
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
#other funcitons
set_index
.
click
(
show_image_info
,
_js
=
"images_history_get_current_img"
,
inputs
=
[
tabname_box
,
img_path
,
filenames
],
outputs
=
[
img_file_name
,
image_index
,
hidden
])
# other funcitons
img_file_name
.
change
(
fn
=
None
,
_js
=
"images_history_enable_del_buttons"
,
inputs
=
None
,
outputs
=
None
)
set_index
.
click
(
show_image_info
,
_js
=
"images_history_get_current_img"
,
inputs
=
[
tabname_box
,
img_path
,
filenames
],
outputs
=
[
img_file_name
,
image_index
,
hidden
])
delete
.
click
(
delete_image
,
_js
=
"images_history_delete"
,
inputs
=
[
delete_num
,
tabname_box
,
img_path
,
img_file_name
,
page_index
,
filenames
,
image_index
],
outputs
=
[
filenames
,
delete_num
])
img_file_name
.
change
(
fn
=
None
,
_js
=
"images_history_enable_del_buttons"
,
inputs
=
None
,
outputs
=
None
)
hidden
.
change
(
fn
=
run_pnginfo
,
inputs
=
[
hidden
],
outputs
=
[
info1
,
img_file_info
,
info2
])
delete
.
click
(
delete_image
,
_js
=
"images_history_delete"
,
inputs
=
[
delete_num
,
tabname_box
,
img_path
,
img_file_name
,
page_index
,
filenames
,
image_index
],
outputs
=
[
filenames
,
delete_num
])
hidden
.
change
(
fn
=
run_pnginfo
,
inputs
=
[
hidden
],
outputs
=
[
info1
,
img_file_info
,
info2
])
#pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict
[
"fn"
](
pnginfo_send_to_txt2img
,
switch_dict
[
"t2i"
],
img_file_info
,
'switch_to_txt2img'
)
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict
[
"fn"
](
pnginfo_send_to_img2img
,
switch_dict
[
"i2i"
],
img_file_info
,
'switch_to_img2img_img2img'
)
switch_dict
[
"fn"
](
pnginfo_send_to_txt2img
,
switch_dict
[
"t2i"
],
img_file_info
,
'switch_to_txt2img'
)
switch_dict
[
"fn"
](
pnginfo_send_to_img2img
,
switch_dict
[
"i2i"
],
img_file_info
,
'switch_to_img2img_img2img'
)
def
create_history_tabs
(
gr
,
opts
,
run_pnginfo
,
switch_dict
):
def
create_history_tabs
(
gr
,
opts
,
run_pnginfo
,
switch_dict
):
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history
:
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history
:
with
gr
.
Tabs
()
as
tabs
:
with
gr
.
Tabs
()
as
tabs
:
with
gr
.
Tab
(
"txt2img history"
):
with
gr
.
Tab
(
"txt2img history"
):
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history_txt2img
:
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history_txt2img
:
show_images_history
(
gr
,
opts
,
"txt2img"
,
run_pnginfo
,
switch_dict
)
show_images_history
(
gr
,
opts
,
"txt2img"
,
run_pnginfo
,
switch_dict
)
with
gr
.
Tab
(
"img2img history"
):
with
gr
.
Tab
(
"img2img history"
):
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history_img2img
:
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
images_history_img2img
:
show_images_history
(
gr
,
opts
,
"img2img"
,
run_pnginfo
,
switch_dict
)
show_images_history
(
gr
,
opts
,
"img2img"
,
run_pnginfo
,
switch_dict
)
...
...
modules/processing.py
View file @
ea8aa170
...
@@ -140,7 +140,7 @@ class Processed:
...
@@ -140,7 +140,7 @@ class Processed:
self
.
sampler_noise_scheduler_override
=
p
.
sampler_noise_scheduler_override
self
.
sampler_noise_scheduler_override
=
p
.
sampler_noise_scheduler_override
self
.
prompt
=
self
.
prompt
if
type
(
self
.
prompt
)
!=
list
else
self
.
prompt
[
0
]
self
.
prompt
=
self
.
prompt
if
type
(
self
.
prompt
)
!=
list
else
self
.
prompt
[
0
]
self
.
negative_prompt
=
self
.
negative_prompt
if
type
(
self
.
negative_prompt
)
!=
list
else
self
.
negative_prompt
[
0
]
self
.
negative_prompt
=
self
.
negative_prompt
if
type
(
self
.
negative_prompt
)
!=
list
else
self
.
negative_prompt
[
0
]
self
.
seed
=
int
(
self
.
seed
if
type
(
self
.
seed
)
!=
list
else
self
.
seed
[
0
])
self
.
seed
=
int
(
self
.
seed
if
type
(
self
.
seed
)
!=
list
else
self
.
seed
[
0
])
if
self
.
seed
is
not
None
else
-
1
self
.
subseed
=
int
(
self
.
subseed
if
type
(
self
.
subseed
)
!=
list
else
self
.
subseed
[
0
])
if
self
.
subseed
is
not
None
else
-
1
self
.
subseed
=
int
(
self
.
subseed
if
type
(
self
.
subseed
)
!=
list
else
self
.
subseed
[
0
])
if
self
.
subseed
is
not
None
else
-
1
self
.
all_prompts
=
all_prompts
or
[
self
.
prompt
]
self
.
all_prompts
=
all_prompts
or
[
self
.
prompt
]
...
...
modules/textual_inversion/dataset.py
View file @
ea8aa170
...
@@ -24,11 +24,12 @@ class DatasetEntry:
...
@@ -24,11 +24,12 @@ class DatasetEntry:
class
PersonalizedBase
(
Dataset
):
class
PersonalizedBase
(
Dataset
):
def
__init__
(
self
,
data_root
,
width
,
height
,
repeats
,
flip_p
=
0.5
,
placeholder_token
=
"*"
,
model
=
None
,
device
=
None
,
template_file
=
None
,
include_cond
=
False
):
def
__init__
(
self
,
data_root
,
width
,
height
,
repeats
,
flip_p
=
0.5
,
placeholder_token
=
"*"
,
model
=
None
,
device
=
None
,
template_file
=
None
,
include_cond
=
False
,
batch_size
=
1
):
re_word
=
re
.
compile
(
shared
.
opts
.
dataset_filename_word_regex
)
if
len
(
shared
.
opts
.
dataset_filename_word_regex
)
>
0
else
None
re_word
=
re
.
compile
(
shared
.
opts
.
dataset_filename_word_regex
)
if
len
(
shared
.
opts
.
dataset_filename_word_regex
)
>
0
else
None
self
.
placeholder_token
=
placeholder_token
self
.
placeholder_token
=
placeholder_token
self
.
batch_size
=
batch_size
self
.
width
=
width
self
.
width
=
width
self
.
height
=
height
self
.
height
=
height
self
.
flip
=
transforms
.
RandomHorizontalFlip
(
p
=
flip_p
)
self
.
flip
=
transforms
.
RandomHorizontalFlip
(
p
=
flip_p
)
...
@@ -78,14 +79,14 @@ class PersonalizedBase(Dataset):
...
@@ -78,14 +79,14 @@ class PersonalizedBase(Dataset):
if
include_cond
:
if
include_cond
:
entry
.
cond_text
=
self
.
create_text
(
filename_text
)
entry
.
cond_text
=
self
.
create_text
(
filename_text
)
entry
.
cond
=
cond_model
([
entry
.
cond_text
])
.
to
(
devices
.
cpu
)
entry
.
cond
=
cond_model
([
entry
.
cond_text
])
.
to
(
devices
.
cpu
)
.
squeeze
(
0
)
self
.
dataset
.
append
(
entry
)
self
.
dataset
.
append
(
entry
)
assert
len
(
self
.
dataset
)
>
1
,
"No images have been found in the dataset."
assert
len
(
self
.
dataset
)
>
1
,
"No images have been found in the dataset."
self
.
length
=
len
(
self
.
dataset
)
*
repeats
self
.
length
=
len
(
self
.
dataset
)
*
repeats
//
batch_size
self
.
initial_indexes
=
np
.
arange
(
self
.
length
)
%
len
(
self
.
dataset
)
self
.
initial_indexes
=
np
.
arange
(
len
(
self
.
dataset
)
)
self
.
indexes
=
None
self
.
indexes
=
None
self
.
shuffle
()
self
.
shuffle
()
...
@@ -102,13 +103,19 @@ class PersonalizedBase(Dataset):
...
@@ -102,13 +103,19 @@ class PersonalizedBase(Dataset):
return
self
.
length
return
self
.
length
def
__getitem__
(
self
,
i
):
def
__getitem__
(
self
,
i
):
if
i
%
len
(
self
.
dataset
)
==
0
:
res
=
[]
self
.
shuffle
()
for
j
in
range
(
self
.
batch_size
):
position
=
i
*
self
.
batch_size
+
j
if
position
%
len
(
self
.
indexes
)
==
0
:
self
.
shuffle
()
index
=
self
.
indexes
[
position
%
len
(
self
.
indexes
)]
entry
=
self
.
dataset
[
index
]
index
=
self
.
indexes
[
i
%
len
(
self
.
indexes
)]
if
entry
.
cond
is
None
:
entry
=
self
.
dataset
[
index
]
entry
.
cond_text
=
self
.
create_text
(
entry
.
filename_text
)
if
entry
.
cond
is
None
:
res
.
append
(
entry
)
entry
.
cond_text
=
self
.
create_text
(
entry
.
filename_text
)
return
entry
return
res
modules/textual_inversion/textual_inversion.py
View file @
ea8aa170
...
@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
...
@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
})
def
train_embedding
(
embedding_name
,
learn_rate
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
create_image_every
,
save_embedding_every
,
template_file
,
save_image_with_stored_embedding
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
def
train_embedding
(
embedding_name
,
learn_rate
,
batch_size
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
create_image_every
,
save_embedding_every
,
template_file
,
save_image_with_stored_embedding
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
assert
embedding_name
,
'embedding not selected'
assert
embedding_name
,
'embedding not selected'
shared
.
state
.
textinfo
=
"Initializing textual inversion training..."
shared
.
state
.
textinfo
=
"Initializing textual inversion training..."
...
@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
...
@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared
.
state
.
textinfo
=
f
"Preparing dataset from {html.escape(data_root)}..."
shared
.
state
.
textinfo
=
f
"Preparing dataset from {html.escape(data_root)}..."
with
torch
.
autocast
(
"cuda"
):
with
torch
.
autocast
(
"cuda"
):
ds
=
modules
.
textual_inversion
.
dataset
.
PersonalizedBase
(
data_root
=
data_root
,
width
=
training_width
,
height
=
training_height
,
repeats
=
shared
.
opts
.
training_image_repeats_per_epoch
,
placeholder_token
=
embedding_name
,
model
=
shared
.
sd_model
,
device
=
devices
.
device
,
template_file
=
template_file
)
ds
=
modules
.
textual_inversion
.
dataset
.
PersonalizedBase
(
data_root
=
data_root
,
width
=
training_width
,
height
=
training_height
,
repeats
=
shared
.
opts
.
training_image_repeats_per_epoch
,
placeholder_token
=
embedding_name
,
model
=
shared
.
sd_model
,
device
=
devices
.
device
,
template_file
=
template_file
,
batch_size
=
batch_size
)
hijack
=
sd_hijack
.
model_hijack
hijack
=
sd_hijack
.
model_hijack
...
@@ -251,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
...
@@ -251,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer
=
torch
.
optim
.
AdamW
([
embedding
.
vec
],
lr
=
scheduler
.
learn_rate
)
optimizer
=
torch
.
optim
.
AdamW
([
embedding
.
vec
],
lr
=
scheduler
.
learn_rate
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
for
i
,
entr
y
in
pbar
:
for
i
,
entr
ies
in
pbar
:
embedding
.
step
=
i
+
ititial_step
embedding
.
step
=
i
+
ititial_step
scheduler
.
apply
(
optimizer
,
embedding
.
step
)
scheduler
.
apply
(
optimizer
,
embedding
.
step
)
...
@@ -262,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
...
@@ -262,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
break
with
torch
.
autocast
(
"cuda"
):
with
torch
.
autocast
(
"cuda"
):
c
=
cond_model
([
entry
.
cond_text
])
c
=
cond_model
([
entry
.
cond_text
for
entry
in
entries
])
x
=
torch
.
stack
([
entry
.
latent
for
entry
in
entries
])
.
to
(
devices
.
device
)
x
=
entry
.
latent
.
to
(
devices
.
device
)
loss
=
shared
.
sd_model
(
x
,
c
)[
0
]
loss
=
shared
.
sd_model
(
x
.
unsqueeze
(
0
),
c
)[
0
]
del
x
del
x
losses
[
embedding
.
step
%
losses
.
shape
[
0
]]
=
loss
.
item
()
losses
[
embedding
.
step
%
losses
.
shape
[
0
]]
=
loss
.
item
()
...
@@ -307,7 +306,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
...
@@ -307,7 +306,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p
.
width
=
preview_width
p
.
width
=
preview_width
p
.
height
=
preview_height
p
.
height
=
preview_height
else
:
else
:
p
.
prompt
=
entr
y
.
cond_text
p
.
prompt
=
entr
ies
[
0
]
.
cond_text
p
.
steps
=
20
p
.
steps
=
20
p
.
width
=
training_width
p
.
width
=
training_width
p
.
height
=
training_height
p
.
height
=
training_height
...
@@ -348,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
...
@@ -348,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
<p>
Loss: {losses.mean():.7f}<br/>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(entr
y
.cond_text)}<br/>
Last prompt: {html.escape(entr
ies[0]
.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
</p>
...
...
modules/ui.py
View file @
ea8aa170
...
@@ -433,7 +433,10 @@ def create_toprow(is_img2img):
...
@@ -433,7 +433,10 @@ def create_toprow(is_img2img):
with
gr
.
Row
():
with
gr
.
Row
():
with
gr
.
Column
(
scale
=
80
):
with
gr
.
Column
(
scale
=
80
):
with
gr
.
Row
():
with
gr
.
Row
():
prompt
=
gr
.
Textbox
(
label
=
"Prompt"
,
elem_id
=
f
"{id_part}_prompt"
,
show_label
=
False
,
placeholder
=
"Prompt"
,
lines
=
2
)
prompt
=
gr
.
Textbox
(
label
=
"Prompt"
,
elem_id
=
f
"{id_part}_prompt"
,
show_label
=
False
,
lines
=
2
,
placeholder
=
"Prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with
gr
.
Column
(
scale
=
1
,
elem_id
=
"roll_col"
):
with
gr
.
Column
(
scale
=
1
,
elem_id
=
"roll_col"
):
roll
=
gr
.
Button
(
value
=
art_symbol
,
elem_id
=
"roll"
,
visible
=
len
(
shared
.
artist_db
.
artists
)
>
0
)
roll
=
gr
.
Button
(
value
=
art_symbol
,
elem_id
=
"roll"
,
visible
=
len
(
shared
.
artist_db
.
artists
)
>
0
)
paste
=
gr
.
Button
(
value
=
paste_symbol
,
elem_id
=
"paste"
)
paste
=
gr
.
Button
(
value
=
paste_symbol
,
elem_id
=
"paste"
)
...
@@ -446,7 +449,10 @@ def create_toprow(is_img2img):
...
@@ -446,7 +449,10 @@ def create_toprow(is_img2img):
with
gr
.
Row
():
with
gr
.
Row
():
with
gr
.
Column
(
scale
=
8
):
with
gr
.
Column
(
scale
=
8
):
with
gr
.
Row
():
with
gr
.
Row
():
negative_prompt
=
gr
.
Textbox
(
label
=
"Negative prompt"
,
elem_id
=
"negative_prompt"
,
show_label
=
False
,
placeholder
=
"Negative prompt"
,
lines
=
2
)
negative_prompt
=
gr
.
Textbox
(
label
=
"Negative prompt"
,
elem_id
=
f
"{id_part}_neg_prompt"
,
show_label
=
False
,
lines
=
2
,
placeholder
=
"Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with
gr
.
Column
(
scale
=
1
,
elem_id
=
"roll_col"
):
with
gr
.
Column
(
scale
=
1
,
elem_id
=
"roll_col"
):
sh
=
gr
.
Button
(
elem_id
=
"sh"
,
visible
=
True
)
sh
=
gr
.
Button
(
elem_id
=
"sh"
,
visible
=
True
)
...
@@ -567,8 +573,8 @@ def create_ui(wrap_gradio_gpu_call):
...
@@ -567,8 +573,8 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr
=
gr
.
Checkbox
(
label
=
'Highres. fix'
,
value
=
False
)
enable_hr
=
gr
.
Checkbox
(
label
=
'Highres. fix'
,
value
=
False
)
with
gr
.
Row
(
visible
=
False
)
as
hr_options
:
with
gr
.
Row
(
visible
=
False
)
as
hr_options
:
firstphase_width
=
gr
.
Slider
(
minimum
=
0
,
maximum
=
1024
,
step
=
64
,
label
=
"First
pass width"
,
value
=
0
)
firstphase_width
=
gr
.
Slider
(
minimum
=
0
,
maximum
=
1024
,
step
=
64
,
label
=
"Firstpass width"
,
value
=
0
)
firstphase_height
=
gr
.
Slider
(
minimum
=
0
,
maximum
=
1024
,
step
=
64
,
label
=
"First
pass height"
,
value
=
0
)
firstphase_height
=
gr
.
Slider
(
minimum
=
0
,
maximum
=
1024
,
step
=
64
,
label
=
"Firstpass height"
,
value
=
0
)
denoising_strength
=
gr
.
Slider
(
minimum
=
0.0
,
maximum
=
1.0
,
step
=
0.01
,
label
=
'Denoising strength'
,
value
=
0.7
)
denoising_strength
=
gr
.
Slider
(
minimum
=
0.0
,
maximum
=
1.0
,
step
=
0.01
,
label
=
'Denoising strength'
,
value
=
0.7
)
with
gr
.
Row
(
equal_height
=
True
):
with
gr
.
Row
(
equal_height
=
True
):
...
@@ -1090,7 +1096,7 @@ def create_ui(wrap_gradio_gpu_call):
...
@@ -1090,7 +1096,7 @@ def create_ui(wrap_gradio_gpu_call):
"i2i"
:
img2img_paste_fields
"i2i"
:
img2img_paste_fields
}
}
#
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
images_history
=
img_his
.
create_history_tabs
(
gr
,
opts
,
wrap_gradio_call
(
modules
.
extras
.
run_pnginfo
),
images_history_switch_dict
)
with
gr
.
Blocks
()
as
modelmerger_interface
:
with
gr
.
Blocks
()
as
modelmerger_interface
:
with
gr
.
Row
()
.
style
(
equal_height
=
False
):
with
gr
.
Row
()
.
style
(
equal_height
=
False
):
...
@@ -1166,6 +1172,7 @@ def create_ui(wrap_gradio_gpu_call):
...
@@ -1166,6 +1172,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name
=
gr
.
Dropdown
(
label
=
'Embedding'
,
choices
=
sorted
(
sd_hijack
.
model_hijack
.
embedding_db
.
word_embeddings
.
keys
()))
train_embedding_name
=
gr
.
Dropdown
(
label
=
'Embedding'
,
choices
=
sorted
(
sd_hijack
.
model_hijack
.
embedding_db
.
word_embeddings
.
keys
()))
train_hypernetwork_name
=
gr
.
Dropdown
(
label
=
'Hypernetwork'
,
choices
=
[
x
for
x
in
shared
.
hypernetworks
.
keys
()])
train_hypernetwork_name
=
gr
.
Dropdown
(
label
=
'Hypernetwork'
,
choices
=
[
x
for
x
in
shared
.
hypernetworks
.
keys
()])
learn_rate
=
gr
.
Textbox
(
label
=
'Learning rate'
,
placeholder
=
"Learning rate"
,
value
=
"0.005"
)
learn_rate
=
gr
.
Textbox
(
label
=
'Learning rate'
,
placeholder
=
"Learning rate"
,
value
=
"0.005"
)
batch_size
=
gr
.
Number
(
label
=
'Batch size'
,
value
=
1
,
precision
=
0
)
dataset_directory
=
gr
.
Textbox
(
label
=
'Dataset directory'
,
placeholder
=
"Path to directory with input images"
)
dataset_directory
=
gr
.
Textbox
(
label
=
'Dataset directory'
,
placeholder
=
"Path to directory with input images"
)
log_directory
=
gr
.
Textbox
(
label
=
'Log directory'
,
placeholder
=
"Path to directory where to write outputs"
,
value
=
"textual_inversion"
)
log_directory
=
gr
.
Textbox
(
label
=
'Log directory'
,
placeholder
=
"Path to directory where to write outputs"
,
value
=
"textual_inversion"
)
template_file
=
gr
.
Textbox
(
label
=
'Prompt template file'
,
value
=
os
.
path
.
join
(
script_path
,
"textual_inversion_templates"
,
"style_filewords.txt"
))
template_file
=
gr
.
Textbox
(
label
=
'Prompt template file'
,
value
=
os
.
path
.
join
(
script_path
,
"textual_inversion_templates"
,
"style_filewords.txt"
))
...
@@ -1244,6 +1251,7 @@ def create_ui(wrap_gradio_gpu_call):
...
@@ -1244,6 +1251,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs
=
[
inputs
=
[
train_embedding_name
,
train_embedding_name
,
learn_rate
,
learn_rate
,
batch_size
,
dataset_directory
,
dataset_directory
,
log_directory
,
log_directory
,
training_width
,
training_width
,
...
@@ -1268,6 +1276,7 @@ def create_ui(wrap_gradio_gpu_call):
...
@@ -1268,6 +1276,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs
=
[
inputs
=
[
train_hypernetwork_name
,
train_hypernetwork_name
,
learn_rate
,
learn_rate
,
batch_size
,
dataset_directory
,
dataset_directory
,
log_directory
,
log_directory
,
steps
,
steps
,
...
@@ -1487,7 +1496,7 @@ Requested path was: {f}
...
@@ -1487,7 +1496,7 @@ Requested path was: {f}
(
img2img_interface
,
"img2img"
,
"img2img"
),
(
img2img_interface
,
"img2img"
,
"img2img"
),
(
extras_interface
,
"Extras"
,
"extras"
),
(
extras_interface
,
"Extras"
,
"extras"
),
(
pnginfo_interface
,
"PNG Info"
,
"pnginfo"
),
(
pnginfo_interface
,
"PNG Info"
,
"pnginfo"
),
#
(images_history, "History", "images_history"),
(
images_history
,
"History"
,
"images_history"
),
(
modelmerger_interface
,
"Checkpoint Merger"
,
"modelmerger"
),
(
modelmerger_interface
,
"Checkpoint Merger"
,
"modelmerger"
),
(
train_interface
,
"Train"
,
"ti"
),
(
train_interface
,
"Train"
,
"ti"
),
(
settings_interface
,
"Settings"
,
"settings"
),
(
settings_interface
,
"Settings"
,
"settings"
),
...
...
script.js
View file @
ea8aa170
...
@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
...
@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
document
.
addEventListener
(
'
keydown
'
,
function
(
e
)
{
document
.
addEventListener
(
'
keydown
'
,
function
(
e
)
{
var
handled
=
false
;
var
handled
=
false
;
if
(
e
.
key
!==
undefined
)
{
if
(
e
.
key
!==
undefined
)
{
if
((
e
.
key
==
"
Enter
"
&&
(
e
.
metaKey
||
e
.
ctrlKey
)))
handled
=
true
;
if
((
e
.
key
==
"
Enter
"
&&
(
e
.
metaKey
||
e
.
ctrlKey
||
e
.
altKey
)))
handled
=
true
;
}
else
if
(
e
.
keyCode
!==
undefined
)
{
}
else
if
(
e
.
keyCode
!==
undefined
)
{
if
((
e
.
keyCode
==
13
&&
(
e
.
metaKey
||
e
.
ctrlKey
)))
handled
=
true
;
if
((
e
.
keyCode
==
13
&&
(
e
.
metaKey
||
e
.
ctrlKey
||
e
.
altKey
)))
handled
=
true
;
}
}
if
(
handled
)
{
if
(
handled
)
{
button
=
get_uiCurrentTabContent
().
querySelector
(
'
button[id$=_generate]
'
);
button
=
get_uiCurrentTabContent
().
querySelector
(
'
button[id$=_generate]
'
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment