Commit 53d67088 authored by drhead, committed by GitHub

Patch timestep embedding to create tensor on-device

parent ddb28b33
import torch
from packaging import version
from einops import repeat
import math

from modules import devices
from modules.sd_hijack_utils import CondFunc

@@ -48,6 +50,30 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
# Monkey patch to create timestep embed tensor on device, avoiding a block.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)

@@ -69,6 +95,7 @@ def hijack_ddpm_edit():
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
...
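For reference, a minimal standalone sketch of the on-device construction this patch relies on: passing device=timesteps.device to torch.arange keeps the frequency table on the same device as the timesteps, instead of building it on the CPU and copying it over on every call. The repeat_only branch and the CondFunc wiring are omitted, and the function name below is illustrative, not part of the commit.

import math
import torch

def timestep_embedding_on_device(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # Build the sinusoidal embedding entirely on timesteps.device.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column when dim is odd.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

# Usage: the embedding stays on whatever device the timestep tensor lives on.
t = torch.randint(0, 1000, (4,), device="cuda" if torch.cuda.is_available() else "cpu")
emb = timestep_embedding_on_device(t, dim=320)
print(emb.shape, emb.device)  # torch.Size([4, 320]) on t's device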