Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
d686e73d
Commit
d686e73d
authored
Jun 26, 2024
by
AUTOMATIC1111
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support for SD3: infinite prompt length, token counting
parent
a8fba9af
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
278 additions
and
139 deletions
+278
-139
modules/models/sd3/sd3_cond.py
modules/models/sd3/sd3_cond.py
+225
-0
modules/models/sd3/sd3_model.py
modules/models/sd3/sd3_model.py
+2
-117
modules/prompt_parser.py
modules/prompt_parser.py
+1
-1
modules/sd_hijack.py
modules/sd_hijack.py
+4
-1
modules/sd_hijack_clip.py
modules/sd_hijack_clip.py
+40
-19
modules/sd_models.py
modules/sd_models.py
+6
-1
No files found.
modules/models/sd3/sd3_cond.py
0 → 100644
View file @
d686e73d
import
os
import
safetensors
import
torch
import
typing
from
transformers
import
CLIPTokenizer
,
T5TokenizerFast
from
modules
import
shared
,
devices
,
modelloader
,
sd_hijack_clip
,
prompt_parser
from
modules.models.sd3.other_impls
import
SDClipModel
,
SDXLClipG
,
T5XXLModel
,
SD3Tokenizer
class SafetensorsMapping(typing.Mapping):
    """Read-only mapping adapter around an opened safetensors file handle.

    The handle only needs to provide ``keys()`` and ``get_tensor(key)``;
    every mapping operation is forwarded to one of those two calls.
    """

    def __init__(self, file):
        # Keep a reference to the handle; nothing is read here.
        self.file = file

    def __len__(self):
        """Return the number of keys reported by the handle."""
        names = self.file.keys()
        return len(names)

    def __iter__(self):
        """Iterate over the keys reported by the handle."""
        yield from self.file.keys()

    def __getitem__(self, key):
        """Return the value the handle's ``get_tensor`` gives for ``key``."""
        return self.file.get_tensor(key)
# CLIP-L text encoder: weights download location and model configuration.
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = dict(
    hidden_act="quick_gelu",
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
)

# CLIP-G text encoder: weights download location and model configuration.
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = dict(
    hidden_act="gelu",
    hidden_size=1280,
    intermediate_size=5120,
    num_attention_heads=20,
    num_hidden_layers=32,
)

# T5-XXL text encoder: weights download location and model configuration.
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = dict(
    d_ff=10240,
    d_model=4096,
    num_heads=64,
    num_layers=24,
    vocab_size=32128,
)
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    """Text conditioning model that feeds the same tokens to the CLIP-L and CLIP-G encoders."""

    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # Tokenizing an empty string yields only the special tokens: the first
        # id is used as the start id, the second as both the end and pad id.
        special_ids = self.tokenizer('')["input_ids"]
        self.id_start = special_ids[0]
        self.id_end = special_ids[1]
        self.id_pad = special_ids[1]

        self.return_pooled = True

    def tokenize(self, texts):
        """Return the tokenizer's input ids for ``texts``, without truncation or special tokens."""
        encoded = self.tokenizer(texts, truncation=False, add_special_tokens=False)
        return encoded["input_ids"]

    def encode_with_transformers(self, tokens):
        """Encode ``tokens`` with both encoders and return the combined output.

        The result is the two encoder outputs concatenated on the last axis
        and zero-padded there to width 4096; the concatenated pooled outputs
        are attached to it as the ``pooled`` attribute.
        """
        # The CLIP-G copy has every position after the first end token zeroed.
        g_tokens = tokens.clone()
        for row_idx, row_ids in enumerate(g_tokens.cpu().tolist()):
            end_pos = row_ids.index(self.id_end)
            g_tokens[row_idx, end_pos + 1:] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(g_tokens)

        joined = torch.cat([l_out, g_out], dim=-1)
        missing_width = 4096 - joined.shape[-1]
        joined = torch.nn.functional.pad(joined, (0, missing_width))

        joined.pooled = torch.cat((l_pooled, g_pooled), dim=-1)
        return joined

    def encode_embedding_init_text(self, init_text, nvpt):
        """Return an all-zero ``(nvpt, 768 + 1280)`` tensor; ``init_text`` is ignored."""
        width = 768 + 1280
        return torch.zeros((nvpt, width), device=devices.device)  # XXX
class Sd3T5(torch.nn.Module):
    """Wrapper that tokenizes prompts and runs them through the T5-XXL encoder.

    When no encoder was supplied, or ``shared.opts.sd3_enable_t5`` is off,
    ``forward`` returns zeros instead of calling the encoder.
    """

    def __init__(self, t5xxl):
        """Store the encoder (may be None) and set up the tokenizer.

        Args:
            t5xxl: the T5 encoder module, or None when T5 is disabled.
        """
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        # An empty string padded to length 2 gives two ids; the first is used
        # as the end id and the second as the padding id.
        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        """Return the tokenizer's input ids for ``texts``, without truncation or special tokens."""
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        """Tokenize one prompt line into parallel token and multiplier lists.

        Args:
            line: the prompt text.
            target_token_count: when given, the result is padded with the pad
                id (multiplier 1.0) or cut to exactly this many entries.

        Returns:
            ``(tokens, multipliers)``, two lists of equal length.
        """
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            # BREAK markers from the prompt parser contribute no tokens here.
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # Compute the pad length once, before either list grows.
                # Recomputing it after padding ``tokens`` would give 0 and
                # leave ``multipliers`` shorter than ``tokens``.
                pad_count = target_token_count - len(tokens)
                tokens += [self.id_pad] * pad_count
                multipliers += [1.0] * pad_count
            else:
                # NOTE(review): cutting here can drop the trailing end token
                # for long prompts; kept as in the original.
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        """Encode ``texts`` with T5, each fitted to ``token_count`` tokens.

        Args:
            texts: list of prompt strings.
            token_count: number of tokens each prompt is padded or cut to.

        Returns:
            The first element of the encoder's output, or a zero tensor of
            shape ``(len(texts), token_count, 4096)`` when T5 is unavailable
            or disabled.
        """
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        # Only the token ids are passed to the encoder; multipliers are not used here.
        tokens_batch = [
            self.tokenize_line(text, target_token_count=token_count)[0]
            for text in texts
        ]

        t5_out, _ = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        """Return an all-zero ``(nvpt, 4096)`` tensor; ``init_text`` is ignored."""
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX
class SD3Cond(torch.nn.Module):
    """Combined SD3 text conditioner built from the CLIP-L, CLIP-G and optional T5 encoders."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(
                layer="hidden",
                layer_idx=-2,
                device="cpu",
                dtype=devices.dtype,
                layer_norm_hidden_state=False,
                return_projected_pooled=False,
                textmodel_json_config=CLIPL_CONFIG,
            )

            # The T5 encoder is only constructed when the option is enabled.
            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        """Return a dict with 'crossattn' and 'vector' conditioning for ``prompts``."""
        clip_out, pooled = self.model_lg(prompts)

        # The T5 wrapper is asked for as many tokens as the CLIP output has on axis 1.
        t5_out = self.model_t5(prompts, token_count=clip_out.shape[1])

        return {
            'crossattn': torch.cat([clip_out, t5_out], dim=-2),
            'vector': pooled,
        }

    def load_weights(self):
        """Fetch and load the encoder weights; does nothing once they are loaded."""
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        # (encoder, url, file name, extra load_state_dict arguments), in load order.
        encoders = [
            (self.clip_g, CLIPG_URL, "clip_g.safetensors", {}),
            (self.clip_l, CLIPL_URL, "clip_l.safetensors", {"strict": False}),
        ]
        if self.t5xxl:
            encoders.append((self.t5xxl, T5_URL, "t5xxl_fp16.safetensors", {"strict": False}))

        for encoder, url, file_name, load_kwargs in encoders:
            weights_file = modelloader.load_file_from_url(url, model_dir=clip_path, file_name=file_name)
            with safetensors.safe_open(weights_file, framework="pt") as file:
                encoder.transformer.load_state_dict(SafetensorsMapping(file), **load_kwargs)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        """Return a 1x1 zero tensor; both arguments are ignored."""
        return torch.tensor([[0]], device=devices.device)  # XXX

    def medvram_modules(self):
        """Return the three encoder attributes as a list (``t5xxl`` may be None)."""
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        """Return the token count that ``model_lg.process_texts`` reports for ``text``."""
        processed = self.model_lg.process_texts([text])
        _, token_count = processed
        return token_count

    def get_target_prompt_token_count(self, token_count):
        """Delegate to ``model_lg.get_target_prompt_token_count``."""
        return self.model_lg.get_target_prompt_token_count(token_count)
modules/models/sd3/sd3_model.py
View file @
d686e73d
import
contextlib
import
contextlib
import
os
from
typing
import
Mapping
import
safetensors
import
torch
import
torch
import
k_diffusion
import
k_diffusion
from
modules.models.sd3.other_impls
import
SDClipModel
,
SDXLClipG
,
T5XXLModel
,
SD3Tokenizer
from
modules.models.sd3.sd3_impls
import
BaseModel
,
SDVAE
,
SD3LatentFormat
from
modules.models.sd3.sd3_impls
import
BaseModel
,
SDVAE
,
SD3LatentFormat
from
modules.models.sd3.sd3_cond
import
SD3Cond
from
modules
import
shared
,
modelloader
,
devices
from
modules
import
shared
,
devices
CLIPG_URL
=
"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG
=
{
"hidden_act"
:
"gelu"
,
"hidden_size"
:
1280
,
"intermediate_size"
:
5120
,
"num_attention_heads"
:
20
,
"num_hidden_layers"
:
32
,
}
CLIPL_URL
=
"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG
=
{
"hidden_act"
:
"quick_gelu"
,
"hidden_size"
:
768
,
"intermediate_size"
:
3072
,
"num_attention_heads"
:
12
,
"num_hidden_layers"
:
12
,
}
T5_URL
=
"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG
=
{
"d_ff"
:
10240
,
"d_model"
:
4096
,
"num_heads"
:
64
,
"num_layers"
:
24
,
"vocab_size"
:
32128
,
}
class
SafetensorsMapping
(
Mapping
):
def
__init__
(
self
,
file
):
self
.
file
=
file
def
__len__
(
self
):
return
len
(
self
.
file
.
keys
())
def
__iter__
(
self
):
for
key
in
self
.
file
.
keys
():
yield
key
def
__getitem__
(
self
,
key
):
return
self
.
file
.
get_tensor
(
key
)
class
SD3Cond
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
()
.
__init__
(
*
args
,
**
kwargs
)
self
.
tokenizer
=
SD3Tokenizer
()
with
torch
.
no_grad
():
self
.
clip_g
=
SDXLClipG
(
CLIPG_CONFIG
,
device
=
"cpu"
,
dtype
=
devices
.
dtype
)
self
.
clip_l
=
SDClipModel
(
layer
=
"hidden"
,
layer_idx
=-
2
,
device
=
"cpu"
,
dtype
=
devices
.
dtype
,
layer_norm_hidden_state
=
False
,
return_projected_pooled
=
False
,
textmodel_json_config
=
CLIPL_CONFIG
)
if
shared
.
opts
.
sd3_enable_t5
:
self
.
t5xxl
=
T5XXLModel
(
T5_CONFIG
,
device
=
"cpu"
,
dtype
=
devices
.
dtype
)
else
:
self
.
t5xxl
=
None
self
.
weights_loaded
=
False
def
forward
(
self
,
prompts
:
list
[
str
]):
res
=
[]
for
prompt
in
prompts
:
tokens
=
self
.
tokenizer
.
tokenize_with_weights
(
prompt
)
l_out
,
l_pooled
=
self
.
clip_l
.
encode_token_weights
(
tokens
[
"l"
])
g_out
,
g_pooled
=
self
.
clip_g
.
encode_token_weights
(
tokens
[
"g"
])
if
self
.
t5xxl
and
shared
.
opts
.
sd3_enable_t5
:
t5_out
,
t5_pooled
=
self
.
t5xxl
.
encode_token_weights
(
tokens
[
"t5xxl"
])
else
:
t5_out
=
torch
.
zeros
(
l_out
.
shape
[
0
:
2
]
+
(
4096
,),
dtype
=
l_out
.
dtype
,
device
=
l_out
.
device
)
lg_out
=
torch
.
cat
([
l_out
,
g_out
],
dim
=-
1
)
lg_out
=
torch
.
nn
.
functional
.
pad
(
lg_out
,
(
0
,
4096
-
lg_out
.
shape
[
-
1
]))
lgt_out
=
torch
.
cat
([
lg_out
,
t5_out
],
dim
=-
2
)
vector_out
=
torch
.
cat
((
l_pooled
,
g_pooled
),
dim
=-
1
)
res
.
append
({
'crossattn'
:
lgt_out
[
0
]
.
to
(
devices
.
device
),
'vector'
:
vector_out
[
0
]
.
to
(
devices
.
device
),
})
return
res
def
load_weights
(
self
):
if
self
.
weights_loaded
:
return
clip_path
=
os
.
path
.
join
(
shared
.
models_path
,
"CLIP"
)
clip_g_file
=
modelloader
.
load_file_from_url
(
CLIPG_URL
,
model_dir
=
clip_path
,
file_name
=
"clip_g.safetensors"
)
with
safetensors
.
safe_open
(
clip_g_file
,
framework
=
"pt"
)
as
file
:
self
.
clip_g
.
transformer
.
load_state_dict
(
SafetensorsMapping
(
file
))
clip_l_file
=
modelloader
.
load_file_from_url
(
CLIPL_URL
,
model_dir
=
clip_path
,
file_name
=
"clip_l.safetensors"
)
with
safetensors
.
safe_open
(
clip_l_file
,
framework
=
"pt"
)
as
file
:
self
.
clip_l
.
transformer
.
load_state_dict
(
SafetensorsMapping
(
file
),
strict
=
False
)
if
self
.
t5xxl
:
t5_file
=
modelloader
.
load_file_from_url
(
T5_URL
,
model_dir
=
clip_path
,
file_name
=
"t5xxl_fp16.safetensors"
)
with
safetensors
.
safe_open
(
t5_file
,
framework
=
"pt"
)
as
file
:
self
.
t5xxl
.
transformer
.
load_state_dict
(
SafetensorsMapping
(
file
),
strict
=
False
)
self
.
weights_loaded
=
True
def
encode_embedding_init_text
(
self
,
init_text
,
nvpt
):
return
torch
.
tensor
([[
0
]],
device
=
devices
.
device
)
# XXX
def
medvram_modules
(
self
):
return
[
self
.
clip_g
,
self
.
clip_l
,
self
.
t5xxl
]
class
SD3Denoiser
(
k_diffusion
.
external
.
DiscreteSchedule
):
class
SD3Denoiser
(
k_diffusion
.
external
.
DiscreteSchedule
):
...
...
modules/prompt_parser.py
View file @
d686e73d
...
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
...
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
class
DictWithShape
(
dict
):
class
DictWithShape
(
dict
):
def
__init__
(
self
,
x
,
shape
):
def
__init__
(
self
,
x
,
shape
=
None
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
update
(
x
)
self
.
update
(
x
)
...
...
modules/sd_hijack.py
View file @
d686e73d
...
@@ -325,7 +325,10 @@ class StableDiffusionModelHijack:
...
@@ -325,7 +325,10 @@ class StableDiffusionModelHijack:
if
self
.
clip
is
None
:
if
self
.
clip
is
None
:
return
"-"
,
"-"
return
"-"
,
"-"
_
,
token_count
=
self
.
clip
.
process_texts
([
text
])
if
hasattr
(
self
.
clip
,
'get_token_count'
):
token_count
=
self
.
clip
.
get_token_count
(
text
)
else
:
_
,
token_count
=
self
.
clip
.
process_texts
([
text
])
return
token_count
,
self
.
clip
.
get_target_prompt_token_count
(
token_count
)
return
token_count
,
self
.
clip
.
get_target_prompt_token_count
(
token_count
)
...
...
modules/sd_hijack_clip.py
View file @
d686e73d
...
@@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
...
@@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class
FrozenCLIPEmbedderWithCustomWordsBase
(
torch
.
nn
.
Module
):
class
TextConditionalModel
(
torch
.
nn
.
Module
):
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
def
__init__
(
self
):
have unlimited prompt length and assign weights to tokens in prompt.
"""
def
__init__
(
self
,
wrapped
,
hijack
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
wrapped
=
wrapped
self
.
hijack
=
sd_hijack
.
model_hijack
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
depending on model."""
self
.
hijack
:
sd_hijack
.
StableDiffusionModelHijack
=
hijack
self
.
chunk_length
=
75
self
.
chunk_length
=
75
self
.
is_trainable
=
getattr
(
wrapped
,
'is_trainable'
,
False
)
self
.
is_trainable
=
False
self
.
input_key
=
getattr
(
wrapped
,
'input_key'
,
'txt'
)
self
.
input_key
=
'txt'
self
.
legacy_ucg_val
=
None
self
.
return_pooled
=
False
self
.
comma_token
=
None
self
.
id_start
=
None
self
.
id_end
=
None
self
.
id_pad
=
None
def
empty_chunk
(
self
):
def
empty_chunk
(
self
):
"""creates an empty PromptChunk and returns it"""
"""creates an empty PromptChunk and returns it"""
...
@@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
...
@@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
"""
if
opts
.
use_old_emphasis_implementation
:
import
modules.sd_hijack_clip_old
return
modules
.
sd_hijack_clip_old
.
forward_old
(
self
,
texts
)
batch_chunks
,
token_count
=
self
.
process_texts
(
texts
)
batch_chunks
,
token_count
=
self
.
process_texts
(
texts
)
used_embeddings
=
{}
used_embeddings
=
{}
...
@@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
...
@@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
if
any
(
x
for
x
in
texts
if
"("
in
x
or
"["
in
x
)
and
opts
.
emphasis
!=
"Original"
:
if
any
(
x
for
x
in
texts
if
"("
in
x
or
"["
in
x
)
and
opts
.
emphasis
!=
"Original"
:
self
.
hijack
.
extra_generation_params
[
"Emphasis"
]
=
opts
.
emphasis
self
.
hijack
.
extra_generation_params
[
"Emphasis"
]
=
opts
.
emphasis
if
getattr
(
self
.
wrapped
,
'return_pooled'
,
False
)
:
if
self
.
return_pooled
:
return
torch
.
hstack
(
zs
),
zs
[
0
]
.
pooled
return
torch
.
hstack
(
zs
),
zs
[
0
]
.
pooled
else
:
else
:
return
torch
.
hstack
(
zs
)
return
torch
.
hstack
(
zs
)
...
@@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
...
@@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
return
z
return
z
class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
    """Wrapper around a FrozenCLIPEmbedder module that allows unlimited prompt
    length and per-token weights in the prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.hijack = hijack

        # Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder
        # or xlmr.BertSeriesModelWithTransformation, depending on model.
        self.wrapped = wrapped

        # Mirror optional attributes of the wrapped module, with defaults
        # for when it does not define them.
        self.is_trainable = getattr(wrapped, 'is_trainable', False)
        self.input_key = getattr(wrapped, 'input_key', 'txt')
        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)

        self.legacy_ucg_val = None  # for sgm codebase

    def forward(self, texts):
        """Encode ``texts``, using the old emphasis implementation when that option is set."""
        if not opts.use_old_emphasis_implementation:
            return super().forward(texts)

        import modules.sd_hijack_clip_old
        return modules.sd_hijack_clip_old.forward_old(self, texts)
class
FrozenCLIPEmbedderWithCustomWords
(
FrozenCLIPEmbedderWithCustomWordsBase
):
class
FrozenCLIPEmbedderWithCustomWords
(
FrozenCLIPEmbedderWithCustomWordsBase
):
def
__init__
(
self
,
wrapped
,
hijack
):
def
__init__
(
self
,
wrapped
,
hijack
):
super
()
.
__init__
(
wrapped
,
hijack
)
super
()
.
__init__
(
wrapped
,
hijack
)
...
...
modules/sd_models.py
View file @
d686e73d
...
@@ -722,7 +722,12 @@ def get_empty_cond(sd_model):
...
@@ -722,7 +722,12 @@ def get_empty_cond(sd_model):
d
=
sd_model
.
get_learned_conditioning
([
""
])
d
=
sd_model
.
get_learned_conditioning
([
""
])
return
d
[
'crossattn'
]
return
d
[
'crossattn'
]
else
:
else
:
return
sd_model
.
cond_stage_model
([
""
])
d
=
sd_model
.
cond_stage_model
([
""
])
if
isinstance
(
d
,
dict
):
d
=
d
[
'crossattn'
]
return
d
def
send_model_to_cpu
(
m
):
def
send_model_to_cpu
(
m
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment