novelai-storage / Stable Diffusion Webui

Commit d4d3134f
authored Oct 28, 2023 by Kohaku-Blueleaf
ManualCast for 10/16 series gpu
parent 0beb131c
Showing 3 changed files with 64 additions and 16 deletions (+64 -16)
modules/devices.py      +51  -6
modules/processing.py    +1  -1
modules/sd_models.py    +12  -9
modules/devices.py
@@ -16,6 +16,23 @@ def has_mps() -> bool:
     return mac_specific.has_mps
 
 
+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None
+        and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -60,8 +77,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
         # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True
 
     torch.backends.cuda.matmul.allow_tf32 = True
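The benchmark workaround itself is unchanged; only the card check is routed through the new helper. A condensed sketch of the resulting behaviour, assuming cuda_no_autocast is passed in from modules/devices.py (the wrapper name enable_tf32_sketch is illustrative):

import torch

def enable_tf32_sketch(cuda_no_autocast):
    # Condensed view of enable_tf32() after this commit: enabling cudnn benchmark
    # lets GTX 16xx cards run fp16 when they otherwise can't (see PR #4407),
    # and TF32 matmul is enabled for cards that support it.
    if torch.cuda.is_available():
        if cuda_no_autocast():
            torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True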
@@ -92,15 +108,44 @@ def cond_cast_float(input):
 
 nv_rng = None
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+
+@contextlib.contextmanager
+def manual_autocast():
+    def manual_cast_forward(self, *args, **kwargs):
+        org_dtype = next(self.parameters()).dtype
+        self.to(dtype)
+        result = self.org_forward(*args, **kwargs)
+        self.to(org_dtype)
+        return result
+
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = manual_cast_forward
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
 
 
-def autocast(disable=False, unet=False):
+def autocast(disable=False):
+    print(fp8, dtype, shared.cmd_opts.precision, device)
     if disable:
         return contextlib.nullcontext()
 
-    if unet and fp8 and device == cpu:
+    if fp8 and device == cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
+    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+        return manual_autocast()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
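manual_autocast() works by monkey-patching the forward methods of common layer classes so that each call temporarily casts the layer's weights to the working dtype and restores the original dtype afterwards, avoiding torch.autocast on cards where it is broken. A self-contained toy version of the same idea (names such as toy_manual_cast are illustrative, and bfloat16 is used so the example also runs on CPU):

import contextlib
import torch

target_dtype = torch.bfloat16  # the real code uses the global dtype from modules/devices.py

def _cast_forward(self, *args, **kwargs):
    # Cast weights to the working dtype for this call only, then cast back.
    original_dtype = next(self.parameters()).dtype
    self.to(target_dtype)
    result = self.org_forward(*args, **kwargs)
    self.to(original_dtype)
    return result

@contextlib.contextmanager
def toy_manual_cast(module_types=(torch.nn.Linear,)):
    # Swap in the casting forward for the listed module classes...
    for module_type in module_types:
        module_type.org_forward = module_type.forward
        module_type.forward = _cast_forward
    try:
        yield
    finally:
        # ...and always restore the original forward on exit.
        for module_type in module_types:
            module_type.forward = module_type.org_forward

layer = torch.nn.Linear(4, 4)                      # weights stored in float32
x = torch.randn(1, 4, dtype=torch.bfloat16)
with toy_manual_cast():
    y = layer(x)                                   # computed in bfloat16
print(y.dtype, next(layer.parameters()).dtype)     # bfloat16, float32 (restored)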
modules/processing.py
@@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if getattr(samples_ddim, 'already_decoded', False):
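The call site in process_images_inner() no longer passes unet=True, since autocast() itself now decides between torch.autocast, the manual-cast path, and a nullcontext. A condensed sketch of the conditional context-manager idiom on this line, with stand-in flags and a hypothetical run_sampling wrapper:

import contextlib

def run_sampling(sample_fn, unet_needs_upcast, autocast, without_autocast):
    # Either suppress autocast entirely (the --upcast-sampling path) or enter
    # whichever context devices.autocast() chose.
    with without_autocast() if unet_needs_upcast else autocast():
        return sample_fn()

# Example usage with trivial stand-ins:
result = run_sampling(
    sample_fn=lambda: "sampled",
    unet_needs_upcast=False,
    autocast=contextlib.nullcontext,
    without_autocast=contextlib.nullcontext,
)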
modules/sd_models.py
@@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if enable_fp8:
         devices.fp8 = True
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-            timer.record("apply fp8 unet for cpu")
-        else:
         if model.is_sdxl:
             cond_stage = model.conditioner
         else:
             cond_stage = model.cond_stage_model
         for module in cond_stage.modules():
             if isinstance(module, torch.nn.Linear):
                 module.to(torch.float8_e4m3fn)
+        if devices.device == devices.cpu:
+            for module in model.model.diffusion_model.modules():
+                if isinstance(module, torch.nn.Conv2d):
+                    module.to(torch.float8_e4m3fn)
+                elif isinstance(module, torch.nn.Linear):
+                    module.to(torch.float8_e4m3fn)
+        else:
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
-            timer.record("apply fp8 unet")
+        timer.record("apply fp8")
     else:
         devices.fp8 = False
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
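The fp8 conversion now runs as one pass: the text encoder's Linear layers are cast, the UNet is cast module-by-module (Conv2d and Linear only) on CPU or wholesale on GPU, and a single "apply fp8" timer record replaces the two earlier ones. A self-contained sketch of the selective cast on a toy module tree (the toy layers are illustrative; torch.float8_e4m3fn needs a recent PyTorch, roughly 2.1 or newer):

import torch

# Toy stand-in for the modules the commit walks over; forward is never called here,
# only the weight storage of Linear/Conv2d layers is re-cast.
toy_model = torch.nn.ModuleDict({
    "linear": torch.nn.Linear(8, 8),
    "conv": torch.nn.Conv2d(3, 8, 3),
    "norm": torch.nn.LayerNorm(8),
})

if hasattr(torch, "float8_e4m3fn"):
    for module in toy_model.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            module.to(torch.float8_e4m3fn)

for name, param in toy_model.named_parameters():
    print(name, param.dtype)  # linear/conv in float8_e4m3fn, norm stays float32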