Stable Diffusion Webui (novelai-storage)

Commit a7116aa9, authored Jun 16, 2024 by Alex "mcmonkey" Goodwin

add SD3 reference implementation from https://github.com/mcmonkey4eva/sd3-ref/

Parent: a30b19dd
Showing 3 changed files with 1482 additions and 0 deletions
modules/models/sd3/mmdit.py        +619  -0
modules/models/sd3/other_impls.py  +492  -0
modules/models/sd3/sd3_impls.py    +371  -0
modules/models/sd3/mmdit.py (new file, 0 → 100644)
### This file contains impls for MM-DiT, the core model component of SD3
import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat

from other_impls import attention, Mlp


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding"""
    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if img_size is not None:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # flatten spatial dim and transpose to channels last, kept for bwd compat
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x


def modulate(x, shift, scale):
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
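
# Illustrative shape check (not part of the original file): the 2D embedding is the
# concatenation of two 1D sin/cos embeddings, one per grid axis, each using half the channels,
# so the result has one row per grid cell and embed_dim columns.
def _demo_sincos_pos_embed(embed_dim=64, grid_size=8):
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
    assert pos_embed.shape == (grid_size * grid_size, embed_dim)
    return pos_embed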
#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t, dtype, **kwargs):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
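
# Illustrative shape check (not part of the original file): three integer timesteps map to a
# (3, 256) table of concatenated cos/sin features before the MLP projects them to hidden_size.
def _demo_timestep_embedding():
    t = torch.tensor([0, 250, 999])
    emb = TimestepEmbedder.timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)
    return emb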
class VectorEmbedder(nn.Module):
    """Embeds a flat vector of dimension input_dim"""
    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


def split_qkv(qkv, head_dim):
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]


def optimized_attention(qkv, num_heads):
    return attention(qkv[0], qkv[1], qkv[2], num_heads)


class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm == "ln":
            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor):
        B, L, C = x.shape
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        x = self.proj(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (q, k, v) = self.pre_attention(x)
        x = attention(q, k, v, self.num_heads)
        x = self.post_attention(x)
        return x


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x


class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        """
        Initialize the FeedForward module.
        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
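
# Illustrative worked example (not part of the original file) of the hidden-dim rounding above:
# with hidden_dim=6144 and multiple_of=256, the constructor first takes 2/3 of it (4096) and then
# rounds up to the next multiple of 256, which is already 4096.
def _demo_swiglu_hidden_dim(hidden_dim=6144, multiple_of=256):
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim  # 4096 for the defaults above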
class DismantledBlock(nn.Module):
    """A DiT block with gated adaptive layer norm (adaLN) conditioning."""
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        dtype=None,
        device=None,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
            else:
                self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
        self.scale_mod_only = scale_mod_only
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
        assert x is not None, "pre_attention called with None input"
        if not self.pre_only:
            if not self.scale_mod_only:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        else:
            if not self.scale_mod_only:
                shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
            else:
                shift_msa = None
                scale_msa = self.adaLN_modulation(c)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        (q, k, v), intermediates = self.pre_attention(x, c)
        attn = attention(q, k, v, self.attn.num_heads)
        return self.post_attention(attn, *intermediates)


def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    o = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    q, k, v = tuple(o)

    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x


class JointBlock(nn.Module):
    """just a small wrapper to serve as a fsdp unit"""
    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        qk_norm = kwargs.pop("qk_norm", None)
        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)

    def forward(self, *args, **kwargs):
        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = (
            nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
            if (total_out_channels is None)
            else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class MMDiT(nn.Module):
    """Diffusion model with a Transformer backbone."""
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size

        # apply magic --> this defines a head_size of 64
        hidden_size = 64 * depth
        num_heads = depth

        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)

        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)

        self.context_embedder = nn.Identity()
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
                for i in range(depth)
            ]
        )

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)

    def cropped_pos_embed(self, hw):
        assert self.pos_embed_max_size is not None
        p = self.x_embedder.patch_size[0]
        h, w = hw
        # patched size
        h = h // p
        w = w // p
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        return spatial_pos_embed

    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = h // p
            w = w // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.register_length > 0:
            context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)

        # context is B, L', D
        # x is B, L, D
        for block in self.joint_blocks:
            context, x = block(context, x, c=c_mod)

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x

    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        context = self.context_embedder(context)

        x = self.forward_core_with_concat(x, c, context)

        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x
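
A minimal shape-check sketch for the file above (an illustration, not part of the commit). It assumes mmdit.py and other_impls.py are importable from the working directory and picks an arbitrary tiny configuration (depth=2, so hidden_size = 64 * 2 = 128); real SD3 checkpoints derive these values from tensor shapes, as BaseModel in sd3_impls.py does.

import torch
from mmdit import MMDiT

model = MMDiT(
    input_size=None, patch_size=2, in_channels=16, depth=2,
    num_patches=256, pos_embed_max_size=16,
    attn_mode="torch", qkv_bias=True, dtype=torch.float32, device="cpu",
)
latent = torch.randn(1, 16, 32, 32)       # (N, C, H, W); 32/2 = 16 patches per side -> 256 tokens
timesteps = torch.randint(0, 1000, (1,))  # (N,)
context = torch.randn(1, 77, 128)         # already at hidden_size, since context_embedder is Identity here
out = model(latent, timesteps, context=context)
print(out.shape)                           # torch.Size([1, 16, 32, 32])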
modules/models/sd3/other_impls.py (new file, 0 → 100644)
### This file contains impls for underlying related models (CLIP, T5, etc)
import torch, math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast


#################################################################################################
### Core/Utility
#################################################################################################


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}


class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output


class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])


class SDTokenizer:
    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize_with_weights(self, text: str):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0))
        to_tokenize = text.replace("\n", " ").split(' ')
        to_tokenize = [x for x in to_tokenize if x != ""]
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        return [batch]


class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
        return out


class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = torch.LongTensor(tokens).to(device)
        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output


class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
    def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################


class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    def __init__(self):
        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x


class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        if past_bias is not None:
            mask = past_bias
        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
        return self.o(out), past_bias


class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
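
A quick sketch of the shape convention that the attention() helper above expects (an illustration, not part of the commit): inputs are packed as (batch, tokens, heads * dim_head), and the output keeps the same packed shape.

import torch
from other_impls import attention

q = torch.randn(2, 10, 8 * 64)   # (batch, tokens, heads * dim_head)
k = torch.randn(2, 10, 8 * 64)
v = torch.randn(2, 10, 8 * 64)
out = attention(q, k, v, heads=8)
print(out.shape)                  # torch.Size([2, 10, 512])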
modules/models/sd3/sd3_impls.py (new file, 0 → 100644)
### Impls of the SD3 core diffusion model and VAE
import torch, math, einops
from mmdit import MMDiT
from PIL import Image


#################################################################################################
### MMDiT Model Wrapping
#################################################################################################


class ModelSamplingDiscreteFlow(torch.nn.Module):
    """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
    def __init__(self, shift=1.0):
        super().__init__()
        self.shift = shift
        timesteps = 1000
        ts = self.sigma(torch.arange(1, timesteps + 1, 1))
        self.register_buffer('sigmas', ts)

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return sigma * 1000

    def sigma(self, timestep: torch.Tensor):
        timestep = timestep / 1000.0
        if self.shift == 1.0:
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        return sigma * noise + (1.0 - sigma) * latent_image


class BaseModel(torch.nn.Module):
    """Wrapper around the core MM-DiT model"""
    def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
        super().__init__()
        # Important configuration values can be quickly determined by checking shapes in the source file
        # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
        patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
        depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
        num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
        pos_embed_max_size = round(math.sqrt(num_patches))
        adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
        context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
        context_embedder_config = {
            "target": "torch.nn.Linear",
            "params": {
                "in_features": context_shape[1],
                "out_features": context_shape[0]
            }
        }
        self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
        self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)

    def apply_model(self, x, sigma, c_crossattn=None, y=None):
        dtype = self.get_dtype()
        timestep = self.model_sampling.timestep(sigma).float()
        model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def forward(self, *args, **kwargs):
        return self.apply_model(*args, **kwargs)

    def get_dtype(self):
        return self.diffusion_model.dtype


class CFGDenoiser(torch.nn.Module):
    """Helper for applying CFG Scaling to diffusion outputs"""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, timestep, cond, uncond, cond_scale):
        # Run cond and uncond in a batch together
        batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
        # Then split and apply CFG Scaling
        pos_out, neg_out = batched.chunk(2)
        scaled = neg_out + (pos_out - neg_out) * cond_scale
        return scaled


class SD3LatentFormat:
    """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
    def __init__(self):
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609

    def process_in(self, latent):
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        return (latent / self.scale_factor) + self.shift_factor

    def decode_latent_to_preview(self, x0):
        """Quick RGB approximate preview of sd3 latents"""
        factors = torch.tensor([
            [-0.0645,  0.0177,  0.1052],
            [ 0.0028,  0.0312,  0.0650],
            [ 0.1848,  0.0762,  0.0360],
            [ 0.0944,  0.0360,  0.0889],
            [ 0.0897,  0.0506, -0.0364],
            [-0.0020,  0.1203,  0.0284],
            [ 0.0855,  0.0118,  0.0283],
            [-0.0539,  0.0658,  0.1047],
            [-0.0057,  0.0116,  0.0700],
            [-0.0412,  0.0281, -0.0039],
            [ 0.1106,  0.1171,  0.1220],
            [-0.0248,  0.0682, -0.0481],
            [ 0.0815,  0.0846,  0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259]
        ], device="cpu")
        latent_image = x0[0].permute(1, 2, 0).cpu() @ factors

        latents_ubyte = (((latent_image + 1) / 2)
            .clamp(0, 1)  # change scale from -1..1 to 0..1
            .mul(0xFF)  # to 0..255
            .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())


#################################################################################################
### K-Diffusion Sampling
#################################################################################################


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    return x[(...,) + (None,) * dims_to_append]


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_euler(model, x, sigmas, extra_args=None):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        sigma_hat = sigmas[i]
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
#################################################################################################
### VAE
#################################################################################################


def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)


class ResnetBlock(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        else:
            self.nin_shortcut = None
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        hidden = x
        hidden = self.norm1(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv1(hidden)
        hidden = self.norm2(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv2(hidden)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + hidden


class AttnBlock(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        hidden = self.norm(x)
        q = self.q(hidden)
        k = self.k(hidden)
        v = self.v(hidden)
        b, c, h, w = q.shape
        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
        hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
        hidden = self.proj_out(hidden)
        return x + hidden


class Downsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class VAEEncoder(torch.nn.Module):
    def __init__(self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = torch.nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = torch.nn.ModuleList()
            attn = torch.nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, dtype=dtype, device=device)
            self.down.append(down)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = self.swish(h)
        h = self.conv_out(h)
        return h


class VAEDecoder(torch.nn.Module):
    def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # upsampling
        self.up = torch.nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = torch.nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            up = torch.nn.Module()
            up.block = block
            if i_level != 0:
                up.upsample = Upsample(block_in, dtype=dtype, device=device)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, z):
        # z to block_in
        hidden = self.conv_in(z)
        # middle
        hidden = self.mid.block_1(hidden)
        hidden = self.mid.attn_1(hidden)
        hidden = self.mid.block_2(hidden)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                hidden = self.up[i_level].block[i_block](hidden)
            if i_level != 0:
                hidden = self.up[i_level].upsample(hidden)
        # end
        hidden = self.norm_out(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv_out(hidden)
        return hidden


class SDVAE(torch.nn.Module):
    def __init__(self, dtype=torch.float32, device=None):
        super().__init__()
        self.encoder = VAEEncoder(dtype=dtype, device=device)
        self.decoder = VAEDecoder(dtype=dtype, device=device)

    @torch.autocast("cuda", dtype=torch.float16)
    def decode(self, latent):
        return self.decoder(latent)

    @torch.autocast("cuda", dtype=torch.float16)
    def encode(self, image):
        hidden = self.encoder(image)
        mean, logvar = torch.chunk(hidden, 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
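
A small sampling sketch for the K-Diffusion section above (illustrative only, not part of the commit): it builds a sigma schedule from ModelSamplingDiscreteFlow and runs sample_euler with a stand-in denoiser in place of CFGDenoiser(BaseModel(...)). Note that sample_euler is decorated with a CUDA autocast; on a CPU-only machine PyTorch warns and disables it.

import torch
from sd3_impls import ModelSamplingDiscreteFlow, sample_euler

sampling = ModelSamplingDiscreteFlow(shift=3.0)
steps = 5
sigs = torch.linspace(float(sampling.sigma_max), float(sampling.sigma_min), steps)
sigmas = torch.cat([sigs, torch.zeros(1)])  # end the schedule at sigma = 0

def fake_denoiser(x, sigma, **extra_args):
    # Stand-in for CFGDenoiser(BaseModel(...)).forward; a real call also needs cond/uncond dicts.
    return torch.zeros_like(x)

x = torch.randn(1, 16, 64, 64) * float(sampling.sigma_max)
out = sample_euler(fake_denoiser, x, sigmas)
print(out.shape)  # torch.Size([1, 16, 64, 64])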