novelai-storage / Stable Diffusion Webui / Commits

Commit e4b4a9c4 authored Dec 18, 2023 by Nuullll
[IPEX] Slice SDPA into smaller chunks
parent de03882d

Showing 1 changed file with 64 additions and 2 deletions

modules/xpu_specific.py (+64, -2)
@@ -27,6 +27,68 @@ def torch_xpu_gc():
 
 has_xpu = check_for_xpu()
 
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs):
+    # cast to same dtype first
+    key = key.to(query.dtype)
+    value = value.to(query.dtype)
+
+    N = query.shape[:-2]  # Batch size
+    L = query.size(-2)  # Target sequence length
+    E = query.size(-1)  # Embedding dimension of the query and key
+    S = key.size(-2)  # Source sequence length
+    Ev = value.size(-1)  # Embedding dimension of the value
+
+    total_batch_size = torch.numel(torch.empty(N))
+    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size()))
+
+    if total_batch_size <= batch_size_limit:
+        return orig_sdp_attn_func(query, key, value, attn_mask, dropout_p, is_causal, *args, **kwargs)
+
+    query = torch.reshape(query, (-1, L, E))
+    key = torch.reshape(key, (-1, S, E))
+    value = torch.reshape(value, (-1, S, Ev))
+    if attn_mask is not None:
+        attn_mask = attn_mask.view(-1, L, S)
+    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+    outputs = []
+    for i in range(chunk_count):
+        attn_mask_chunk = (
+            None
+            if attn_mask is None
+            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+        )
+        chunk_output = orig_sdp_attn_func(
+            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            attn_mask_chunk,
+            dropout_p,
+            is_causal,
+            *args,
+            **kwargs
+        )
+        outputs.append(chunk_output)
+    result = torch.cat(outputs, dim=0)
+    return torch.reshape(result, (*N, L, Ev))
+
+
 if has_xpu:
     # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
     CondFunc('torch.Generator',
@@ -55,5 +117,5 @@ if has_xpu:
         lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
         lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
     CondFunc('torch.nn.functional.scaled_dot_product_attention',
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
+        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+        lambda orig_func, query, *args, **kwargs: query.is_xpu)
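As a rough illustration of the heuristic in the first hunk, the sketch below (not part of the commit) walks through the chunking arithmetic for a hypothetical fp16 attention call on an Arc A750 with 8 GB of VRAM; the shape, dtype, and memory figures are assumptions chosen for the example.

# Illustrative sketch only: mirrors the arithmetic of torch_xpu_scaled_dot_product_attention
# for assumed values; not code from the commit.
total_vram = 8 * 1024 ** 3                                          # hypothetical Arc A750: 8 GB
arc_single_allocation_limit = min(total_vram // 8, 4 * 1024 ** 3)   # heuristic limit -> 1 GB

# Hypothetical attention shape: leading dims (batch=2, heads=8), 4096 query/key tokens, fp16
total_batch_size = 2 * 8     # product of the leading (batch, heads) dimensions
L = S = 4096                 # target / source sequence lengths
element_size = 2             # bytes per fp16 element

# Largest number of L-by-S attention score matrices that fit under the allocation limit
batch_size_limit = max(1, arc_single_allocation_limit // (L * S * element_size))   # -> 32

# 16 <= 32, so this call would fall through to the original SDPA; a call with
# total_batch_size = 64 would instead be split into ceil(64 / 32) = 2 chunks.
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
print(batch_size_limit, chunk_count)   # 32 1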
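The two-line change in the second hunk reroutes torch.nn.functional.scaled_dot_product_attention to the new sliced implementation whenever the query tensor is on an XPU device, where the previous patch only cast key/value to the query dtype. CondFunc itself lives elsewhere in the repository and is not shown in this diff; the snippet below is a minimal sketch assuming it behaves like a conditional monkey-patch (run the substitute when the predicate holds, otherwise call the original), and all names in it are illustrative.

# Minimal sketch of the conditional-patch pattern assumed above; not the actual CondFunc.
def cond_patch(orig_func, sub_func, cond_func):
    # Wrap orig_func so that sub_func handles the call only when cond_func says so.
    def wrapper(*args, **kwargs):
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)
    return wrapper

# With this commit, the predicate is `query.is_xpu`, so CPU and CUDA callers keep the stock
# SDPA path and only XPU tensors go through torch_xpu_scaled_dot_product_attention.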