Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
209c26a1
Commit
209c26a1
authored
Jan 09, 2024
by
Kohaku-Blueleaf
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve efficiency and support more device
parent
6869d958
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
17 deletions
+44
-17
modules/devices.py
modules/devices.py
+43
-17
modules/shared_init.py
modules/shared_init.py
+1
-0
No files found.
modules/devices.py
View file @
209c26a1
...
...
@@ -110,6 +110,7 @@ device_codeformer: torch.device = None
dtype
:
torch
.
dtype
=
torch
.
float16
dtype_vae
:
torch
.
dtype
=
torch
.
float16
dtype_unet
:
torch
.
dtype
=
torch
.
float16
dtype_inference
:
torch
.
dtype
=
torch
.
float16
unet_needs_upcast
=
False
...
...
@@ -131,21 +132,49 @@ patch_module_list = [
]
def
manual_cast_forward
(
self
,
*
args
,
**
kwargs
):
org_dtype
=
torch_utils
.
get_param
(
self
)
.
dtype
self
.
to
(
dtype
)
args
=
[
arg
.
to
(
dtype
)
if
isinstance
(
arg
,
torch
.
Tensor
)
else
arg
for
arg
in
args
]
kwargs
=
{
k
:
v
.
to
(
dtype
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
result
=
self
.
org_forward
(
*
args
,
**
kwargs
)
self
.
to
(
org_dtype
)
return
result
def
manual_cast_forward
(
target_dtype
):
def
forward_wrapper
(
self
,
*
args
,
**
kwargs
):
org_dtype
=
torch_utils
.
get_param
(
self
)
.
dtype
if
not
target_dtype
==
org_dtype
==
dtype_inference
:
self
.
to
(
target_dtype
)
args
=
[
arg
.
to
(
target_dtype
)
if
isinstance
(
arg
,
torch
.
Tensor
)
else
arg
for
arg
in
args
]
kwargs
=
{
k
:
v
.
to
(
target_dtype
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
for
k
,
v
in
kwargs
.
items
()
}
result
=
self
.
org_forward
(
*
args
,
**
kwargs
)
self
.
to
(
org_dtype
)
if
target_dtype
!=
dtype_inference
:
if
isinstance
(
result
,
tuple
):
result
=
tuple
(
i
.
to
(
dtype_inference
)
if
isinstance
(
i
,
torch
.
Tensor
)
else
i
for
i
in
result
)
elif
isinstance
(
result
,
torch
.
Tensor
):
result
=
result
.
to
(
dtype_inference
)
return
result
return
forward_wrapper
@
contextlib
.
contextmanager
def
manual_cast
():
def
manual_cast
(
target_dtype
):
for
module_type
in
patch_module_list
:
org_forward
=
module_type
.
forward
module_type
.
forward
=
manual_cast_forward
if
module_type
==
torch
.
nn
.
MultiheadAttention
and
has_xpu
():
module_type
.
forward
=
manual_cast_forward
(
torch
.
float32
)
else
:
module_type
.
forward
=
manual_cast_forward
(
target_dtype
)
module_type
.
org_forward
=
org_forward
try
:
yield
None
...
...
@@ -161,15 +190,12 @@ def autocast(disable=False):
if
fp8
and
device
==
cpu
:
return
torch
.
autocast
(
"cpu"
,
dtype
=
torch
.
bfloat16
,
enabled
=
True
)
if
fp8
and
(
dtype
==
torch
.
float32
or
shared
.
cmd_opts
.
precision
==
"full"
or
cuda_no_autocast
()):
return
manual_cast
()
if
has_mps
()
and
shared
.
cmd_opts
.
precision
!=
"full"
:
return
manual_cast
()
if
dtype
==
torch
.
float32
or
shared
.
cmd_opts
.
precision
==
"full"
:
if
dtype
==
torch
.
float32
and
shared
.
cmd_opts
.
precision
==
"full"
:
return
contextlib
.
nullcontext
()
if
has_xpu
()
or
has_mps
()
or
cuda_no_autocast
():
return
manual_cast
(
dtype_inference
)
return
torch
.
autocast
(
"cuda"
)
...
...
modules/shared_init.py
View file @
209c26a1
...
...
@@ -29,6 +29,7 @@ def initialize():
devices
.
dtype
=
torch
.
float32
if
cmd_opts
.
no_half
else
torch
.
float16
devices
.
dtype_vae
=
torch
.
float32
if
cmd_opts
.
no_half
or
cmd_opts
.
no_half_vae
else
torch
.
float16
devices
.
dtype_inference
=
torch
.
float32
if
cmd_opts
.
precision
==
'full'
else
devices
.
dtype
shared
.
device
=
devices
.
device
shared
.
weight_load_location
=
None
if
cmd_opts
.
lowram
else
"cpu"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment