Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
H
Hydra Node Http
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Hydra Node Http
Commits
d3e32e79
Commit
d3e32e79
authored
Aug 19, 2022
by
kurumuz
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
StableInterface for k-diffusion
parent
b43d6ea1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
1 deletion
+20
-1
hydra_node/models.py
hydra_node/models.py
+20
-1
No files found.
hydra_node/models.py
View file @
d3e32e79
...
@@ -105,6 +105,25 @@ def sanitize_image(image):
...
@@ -105,6 +105,25 @@ def sanitize_image(image):
image
=
image
.
convert
(
'RGB'
)
image
=
image
.
convert
(
'RGB'
)
return
image
return
image
class
StableInterface
(
nn
.
Module
):
def
__init__
(
self
,
model
,
thresholder
=
None
):
super
()
.
__init__
()
self
.
inner_model
=
model
self
.
sigma_to_t
=
model
.
sigma_to_t
self
.
thresholder
=
thresholder
self
.
get_sigmas
=
model
.
get_sigmas
@
torch
.
no_grad
()
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
):
x_two
=
torch
.
cat
([
x
]
*
2
)
sigma_two
=
torch
.
cat
([
sigma
]
*
2
)
cond_full
=
torch
.
cat
([
uncond
,
cond
])
uncond
,
cond
=
self
.
inner_model
(
x_two
,
sigma_two
,
cond
=
cond_full
)
.
chunk
(
2
)
x_0
=
uncond
+
(
cond
-
uncond
)
*
cond_scale
if
self
.
thresholder
is
not
None
:
x_0
=
self
.
thresholder
(
x_0
)
return
x_0
class
StableDiffusionModel
(
nn
.
Module
):
class
StableDiffusionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -116,7 +135,7 @@ class StableDiffusionModel(nn.Module):
...
@@ -116,7 +135,7 @@ class StableDiffusionModel(nn.Module):
else
:
else
:
typex
=
torch
.
float32
typex
=
torch
.
float32
self
.
k_model
=
K
.
external
.
CompVisDenoiser
(
model
)
self
.
k_model
=
K
.
external
.
CompVisDenoiser
(
model
)
self
.
k_model
=
K
.
external
.
StableInterface
(
self
.
k_model
)
self
.
k_model
=
StableInterface
(
self
.
k_model
)
self
.
device
=
config
.
device
self
.
device
=
config
.
device
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
plms
=
PLMSSampler
(
model
)
self
.
plms
=
PLMSSampler
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment