novelai-storage / Stable Diffusion Webui

Commit 79de09c3
Authored Jun 16, 2024 by AUTOMATIC1111
Commit message: linter
Parent commit: 5b2a60b8
Showing 2 changed files, with 16 additions and 13 deletions:

    modules/models/sd3/other_impls.py   +10  -9
    modules/models/sd3/sd3_impls.py      +6  -4
modules/models/sd3/other_impls.py

 ### This file contains impls for underlying related models (CLIP, T5, etc)
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
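Note: both files in this commit open with the same cleanup, splitting comma-separated imports onto one line each. A minimal sketch of the rule presumably being satisfied (assumption: a pycodestyle-style check such as E401, "multiple imports on one line", which linters like Ruff and flake8 enforce); the two spellings are identical at runtime:

import math, os        # old style: flagged by E401
import math            # lint-clean spelling, same runtime behavior
import os

assert math.sqrt(4) == 2.0 and os.sep in ("/", "\\")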
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
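Note: this hunk swaps map() plus a lambda for a list comprehension (assumption: it answers a rule like Ruff's C417, "unnecessary map"). A self-contained check that the two forms produce identical tensors, using toy sizes rather than the model's real ones:

import torch

b, heads, dim_head, n = 2, 4, 16, 8                 # toy sizes, not from the model
q = k = v = torch.randn(b, n, heads * dim_head)

via_map = list(map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)))
via_comp = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]

assert all(torch.equal(m, c) for m, c in zip(via_map, via_comp))
assert via_comp[0].shape == (b, heads, n, dim_head)  # (batch, heads, tokens, head dim)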
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
         if intermediate_output < 0:
             intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
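Note: the rename from l to layer is cosmetic but lint-driven (assumption: pycodestyle E741, "ambiguous variable name", since l is easily mistaken for 1 or I in many fonts). A trivial illustration of the preferred spelling, with string methods standing in for real modules:

layers = [str.upper, str.lower, str.title]      # stand-ins for self.layers
x = "clip"
for i, layer in enumerate(layers):              # was: for i, l in enumerate(...)
    x = layer(x)
assert x == "Clip"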
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
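Note: same assumed rule (C417) as in attention() above, with one extra simplification: the comprehension also drops the explicit list() wrapper. A sketch on hypothetical (token_id, weight) pairs, since the real values come from the tokenizer:

token_weight_pairs = [[(49406, 1.0), (320, 1.2), (49407, 1.0)]]   # hypothetical data
old = list(map(lambda a: a[0], token_weight_pairs[0]))
new = [a[0] for a in token_weight_pairs[0]]
assert old == new == [49406, 320, 49407]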
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
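Note: the two SDClipModel hunks belong together: the dict literal leaves the signature and comes back as a None check in the body (assumption: this targets the flake8-bugbear/Ruff rule B006, "mutable argument default"). Default values are evaluated once, at function definition time, so a dict default is shared across every call; as far as the visible hunks show the dict is never mutated here, so the change is preventive. The sentinel pattern below reproduces the fix in miniature:

def make_config(special_tokens=None):
    # evaluated per call, so each caller gets a fresh dict
    special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
    return special_tokens

a, b = make_config(), make_config()
assert a == b and a is not b      # equal contents, distinct objects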
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
modules/models/sd3/sd3_impls.py

 ### Impls of the SD3 core diffusion model and VAE
-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
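Note: the mechanical change is the same map-to-comprehension swap (same assumed rule, C417), but the rearrange pattern is worth a closer look: "b c h w -> b 1 (h w) c" flattens the spatial grid into a token axis (with a dummy single-head dimension) so scaled_dot_product_attention can treat pixels as a sequence, and the inverse pattern restores the feature map. A toy round trip:

import torch
import einops

x = torch.randn(2, 8, 4, 4)                                    # b, c, h, w
t = einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous()
assert t.shape == (2, 1, 16, 8)                                # b, 1 head, h*w tokens, c
back = einops.rearrange(t, "b 1 (h w) c -> b c h w", h=4, w=4)
assert torch.equal(back, x)                                    # exact round trip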
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
             attn = torch.nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             down = torch.nn.Module()
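Note: here and in the VAEDecoder hunk below, the loop index was never read inside the body, so it becomes an underscore (assumption: flake8-bugbear/Ruff B007, "unused loop control variable"). A minimal sketch:

blocks = []
for _ in range(3):                # was: for i_block in range(3)
    blocks.append(len(blocks))    # the body only needs the repeat count
assert blocks == [0, 1, 2]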
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = torch.nn.ModuleList()
             block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             up = torch.nn.Module()