Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
H
Hydra Node Http
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Hydra Node Http
Commits
b43d6ea1
Commit
b43d6ea1
authored
Aug 19, 2022
by
kurumuz
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
k-diffusion samplers
parent
56787cb0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
102 additions
and
16 deletions
+102
-16
hydra_node/models.py
hydra_node/models.py
+102
-16
No files found.
hydra_node/models.py
View file @
b43d6ea1
...
@@ -14,6 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
...
@@ -14,6 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
import
time
import
time
from
PIL
import
Image
from
PIL
import
Image
import
k_diffusion
as
K
def
pil_upscale
(
image
,
scale
=
1
):
def
pil_upscale
(
image
,
scale
=
1
):
device
=
image
.
device
device
=
image
.
device
...
@@ -58,6 +59,52 @@ def prompt_mixing(model, prompt_body, batch_size):
...
@@ -58,6 +59,52 @@ def prompt_mixing(model, prompt_body, batch_size):
else
:
else
:
return
fix_batch
(
model
.
get_learned_conditioning
([
prompt_body
]),
batch_size
)
return
fix_batch
(
model
.
get_learned_conditioning
([
prompt_body
]),
batch_size
)
@torch.no_grad()
#@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def encode_image(image, model):
    # Encode an input image into the first-stage latent space.
    #
    # Accepts a PIL Image, a numpy ndarray, or a torch tensor (HWC, uint8-range
    # values). `model` is a first-stage encoder whose .encode(x) returns an
    # object with a .sample() method — presumably a VAE/VQGAN posterior
    # distribution; confirm against the first_stage_model used by callers.
    if isinstance(image, Image.Image):
        image = np.asarray(image)
        # .clone() detaches the tensor from the numpy buffer so later in-place
        # ops cannot corrupt the original PIL data.
        image = torch.from_numpy(image).clone()

    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

    #gets image as tensor in [0, 255] and rescales it to [-1, 1]
    def preprocess_vqgan(x):
        x = x / 255.0
        x = 2.*x - 1.
        return x

    # HWC -> CHW, add batch dim. NOTE(review): CUDA transfer is hard-coded
    # here rather than using a configured device.
    image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
    image = preprocess_vqgan(image)
    image = model.encode(image).sample()

    return image
@torch.no_grad()
def decode_image(image, model):
    """Decode a batched latent tensor back into a PIL RGB image.

    `image` is a (1, C, H, W) latent; `model` is a first-stage decoder whose
    .decode() returns an image tensor in [-1, 1].
    """

    def _to_pil(t):
        # Detach from the autograd graph and move to CPU before converting.
        t = t.detach().float().cpu()
        # Map model output range [-1, 1] to [0, 1], clamping outliers.
        t = (torch.clamp(t, -1., 1.) + 1.) / 2.
        # CHW -> HWC, scale to byte range for PIL.
        arr = (255 * t.permute(1, 2, 0).numpy()).astype(np.uint8)
        pil = Image.fromarray(arr)
        if not pil.mode == "RGB":
            pil = pil.convert("RGB")
        return pil

    decoded = model.decode(image)
    return _to_pil(decoded.squeeze(0))
def sanitize_image(image):
    """Open an image file with PIL and force RGB, dropping any alpha channel."""
    # NOTE(review): no resizing/cropping happens here despite what callers
    # might expect — resizing is done separately (e.g. in sample()).
    return Image.open(image).convert('RGB')
class
StableDiffusionModel
(
nn
.
Module
):
class
StableDiffusionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -68,11 +115,22 @@ class StableDiffusionModel(nn.Module):
...
@@ -68,11 +115,22 @@ class StableDiffusionModel(nn.Module):
typex
=
torch
.
float16
typex
=
torch
.
float16
else
:
else
:
typex
=
torch
.
float32
typex
=
torch
.
float32
self
.
model
=
model
.
to
(
config
.
device
)
.
to
(
typex
)
self
.
k_model
=
K
.
external
.
CompVisDenoiser
(
model
)
self
.
k_model
=
K
.
external
.
StableInterface
(
self
.
k_model
)
self
.
device
=
config
.
device
self
.
device
=
config
.
device
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
plms
=
PLMSSampler
(
model
)
self
.
plms
=
PLMSSampler
(
model
)
self
.
ddim
=
DDIMSampler
(
model
)
self
.
ddim
=
DDIMSampler
(
model
)
self
.
sampler_map
=
{
'plms'
:
self
.
plms
.
sample
,
'ddim'
:
self
.
ddim
.
sample
,
'k_euler'
:
K
.
sampling
.
sample_euler
,
'k_euler_ancestral'
:
K
.
sampling
.
sample_euler_ancestral
,
'k_heun'
:
K
.
sampling
.
sample_heun
,
'k_dpm_2'
:
K
.
sampling
.
sample_dpm_2
,
'k_dpm_2_ancestral'
:
K
.
sampling
.
sample_dpm_2_ancestral
,
'k_lms'
:
K
.
sampling
.
sample_lms
,
}
def
from_folder
(
self
,
folder
):
def
from_folder
(
self
,
folder
):
folder
=
Path
(
folder
)
folder
=
Path
(
folder
)
...
@@ -99,25 +157,41 @@ class StableDiffusionModel(nn.Module):
...
@@ -99,25 +157,41 @@ class StableDiffusionModel(nn.Module):
return
model
return
model
@
torch
.
no_grad
()
@
torch
.
no_grad
()
@
torch
.
autocast
(
"cuda"
,
enabled
=
True
,
dtype
=
torch
.
float16
)
def
sample
(
self
,
request
):
def
sample
(
self
,
request
):
request
=
DotMap
(
request
)
if
request
.
image
is
not
None
:
request
.
sampler
=
"ddim_img2img"
#enforce ddim for now
self
.
ddim
.
make_schedule
(
ddim_num_steps
=
request
.
steps
,
ddim_eta
=
request
.
ddim_eta
,
verbose
=
False
)
image
=
sanitize_image
(
request
.
image
)
image
=
image
.
resize
((
request
.
width
,
request
.
height
),
resample
=
Image
.
Resampling
.
LANCZOS
)
start_code
=
encode_image
(
image
,
self
.
model
.
first_stage_model
)
.
to
(
self
.
device
)
start_code
=
self
.
model
.
get_first_stage_encoding
(
start_code
)
print
(
start_code
.
shape
)
start_code
=
start_code
+
(
torch
.
randn_like
(
start_code
)
*
request
.
noise
)
t_enc
=
int
(
request
.
strength
*
request
.
steps
)
if
request
.
seed
is
not
None
:
if
request
.
seed
is
not
None
:
torch
.
manual_seed
(
request
.
seed
)
torch
.
manual_seed
(
request
.
seed
)
np
.
random
.
seed
(
request
.
seed
)
np
.
random
.
seed
(
request
.
seed
)
if
request
.
plms
:
if
request
.
sampler
.
startswith
(
"k_"
):
sampler
=
self
.
plms
sampler
=
"k-diffusion"
else
:
sampler
=
self
.
ddim
elif
request
.
sampler
==
'ddim_img2img'
:
sampler
=
'img2img'
start_code
=
None
else
:
if
request
.
fixed_code
:
sampler
=
"normal"
start_code
=
torch
.
randn
([
request
.
n_samples
,
if
request
.
image
is
None
:
request
.
latent_channels
,
start_code
=
None
request
.
height
//
request
.
downsampling_factor
,
if
request
.
fixed_code
or
sampler
==
"k-diffusion"
:
request
.
width
//
request
.
downsampling_factor
,
start_code
=
torch
.
randn
([
],
device
=
self
.
device
)
request
.
n_samples
,
request
.
latent_channels
,
request
.
height
//
request
.
downsampling_factor
,
request
.
width
//
request
.
downsampling_factor
,
],
device
=
self
.
device
)
prompt
=
[
request
.
prompt
]
*
request
.
n_samples
prompt
=
[
request
.
prompt
]
*
request
.
n_samples
prompt_condition
=
prompt_mixing
(
self
.
model
,
prompt
[
0
],
request
.
n_samples
)
prompt_condition
=
prompt_mixing
(
self
.
model
,
prompt
[
0
],
request
.
n_samples
)
...
@@ -131,9 +205,9 @@ class StableDiffusionModel(nn.Module):
...
@@ -131,9 +205,9 @@ class StableDiffusionModel(nn.Module):
request
.
height
//
request
.
downsampling_factor
,
request
.
height
//
request
.
downsampling_factor
,
request
.
width
//
request
.
downsampling_factor
request
.
width
//
request
.
downsampling_factor
]
]
with
torch
.
autocast
(
"cuda"
,
enabled
=
self
.
config
.
amp
)
:
if
sampler
==
"normal"
:
with
self
.
model
.
ema_scope
():
with
self
.
model
.
ema_scope
():
samples
,
_
=
s
ampler
.
sample
(
samples
,
_
=
s
elf
.
sampler_map
[
request
.
sampler
]
(
S
=
request
.
steps
,
S
=
request
.
steps
,
conditioning
=
prompt_condition
,
conditioning
=
prompt_condition
,
batch_size
=
request
.
n_samples
,
batch_size
=
request
.
n_samples
,
...
@@ -146,6 +220,18 @@ class StableDiffusionModel(nn.Module):
...
@@ -146,6 +220,18 @@ class StableDiffusionModel(nn.Module):
x_T
=
start_code
,
x_T
=
start_code
,
)
)
elif
sampler
==
'img2img'
:
with
self
.
model
.
ema_scope
():
start_code
=
self
.
ddim
.
stochastic_encode
(
start_code
,
torch
.
tensor
([
t_enc
]
*
request
.
n_samples
)
.
to
(
self
.
device
),
noise
=
None
)
samples
=
self
.
ddim
.
decode
(
start_code
,
prompt_condition
,
t_enc
,
unconditional_guidance_scale
=
request
.
scale
,
unconditional_conditioning
=
uc
)
elif
sampler
==
"k-diffusion"
:
with
self
.
model
.
ema_scope
():
sigmas
=
self
.
k_model
.
get_sigmas
(
request
.
steps
)
start_code
=
start_code
*
sigmas
[
0
]
extra_args
=
{
'cond'
:
prompt_condition
,
'uncond'
:
uc
,
'cond_scale'
:
request
.
scale
}
samples
=
self
.
sampler_map
[
request
.
sampler
](
self
.
k_model
,
start_code
,
sigmas
,
extra_args
=
extra_args
)
x_samples_ddim
=
self
.
model
.
decode_first_stage
(
samples
)
x_samples_ddim
=
self
.
model
.
decode_first_stage
(
samples
)
x_samples_ddim
=
torch
.
clamp
((
x_samples_ddim
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_samples_ddim
=
torch
.
clamp
((
x_samples_ddim
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment