novelai-storage / Stable Diffusion Webui

Commit 18ca987c authored Jan 05, 2024 by Kohaku-Blueleaf
Add general forward method for all modules.
parent a06dab8d
Showing 2 changed files with 39 additions and 7 deletions (+39 -7)
extensions-builtin/Lora/network.py   +33 -1
extensions-builtin/Lora/networks.py  +6 -6
extensions-builtin/Lora/network.py

@@ -3,6 +3,10 @@ import os
 from collections import namedtuple
 import enum
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +119,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -155,5 +182,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
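Taken together, the constructor now records a functional op (F.conv2d, F.linear, F.layer_norm or F.group_norm) plus that layer's keyword arguments, and the new forward reuses them for any module whose subclass implements calc_updown. The snippet below is a minimal, self-contained sketch of that pattern; ToyGeneralModule and its zero-delta calc_updown are hypothetical stand-ins for illustration, not code from the repository.

# Minimal sketch of the "general forward" pattern (hypothetical names;
# ToyGeneralModule is not a class from the repository).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGeneralModule:
    def __init__(self, sd_module):
        self.sd_module = sd_module
        # Pick the functional op and its layer-specific kwargs once,
        # mirroring what the new NetworkModule.__init__ does above.
        self.ops = None
        self.extra_kwargs = {}
        if isinstance(sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {'stride': sd_module.stride, 'padding': sd_module.padding}
        elif isinstance(sd_module, nn.Linear):
            self.ops = F.linear

    def calc_updown(self, weight):
        # Stand-in: a real subclass would build its delta from trained
        # LoRA weights; here we return a zero delta and no extra bias.
        return torch.zeros_like(weight), None

    def forward(self, x, y):
        # Same shape as the new NetworkModule.forward: add this module's
        # contribution on top of the original layer output y.
        if self.ops is None:
            raise NotImplementedError()
        updown, ex_bias = self.calc_updown(self.sd_module.weight)
        return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)


layer = nn.Linear(8, 4)
mod = ToyGeneralModule(layer)
x = torch.randn(2, 8)
y = layer(x)                      # original layer output
print(mod.forward(x, y).shape)    # torch.Size([2, 4]); equal to y here since the delta is zero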
extensions-builtin/Lora/networks.py

@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
 
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None: