Commit fca42949 authored by AUTOMATIC1111's avatar AUTOMATIC1111

rework torchsde._brownian.brownian_interval replacement to use...

rework torchsde._brownian.brownian_interval replacement to use device.randn_local and respect the NV setting.
parent 84b6fcd0
...@@ -71,14 +71,17 @@ def enable_tf32(): ...@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu") cpu: torch.device = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None device: torch.device = None
dtype = torch.float16 device_interrogate: torch.device = None
dtype_vae = torch.float16 device_gfpgan: torch.device = None
dtype_unet = torch.float16 device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False unet_needs_upcast = False
...@@ -94,6 +97,10 @@ nv_rng = None ...@@ -94,6 +97,10 @@ nv_rng = None
def randn(seed, shape): def randn(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
from modules.shared import opts from modules.shared import opts
manual_seed(seed) manual_seed(seed)
...@@ -107,7 +114,27 @@ def randn(seed, shape): ...@@ -107,7 +114,27 @@ def randn(seed, shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
from modules.shared import opts
if opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=device)
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
def randn_like(x): def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts from modules.shared import opts
if opts.randn_source == "NV": if opts.randn_source == "NV":
...@@ -120,6 +147,10 @@ def randn_like(x): ...@@ -120,6 +147,10 @@ def randn_like(x):
def randn_without_seed(shape): def randn_without_seed(shape):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts from modules.shared import opts
if opts.randn_source == "NV": if opts.randn_source == "NV":
...@@ -132,6 +163,7 @@ def randn_without_seed(shape): ...@@ -132,6 +163,7 @@ def randn_without_seed(shape):
def manual_seed(seed): def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts from modules.shared import opts
if opts.randn_source == "NV": if opts.randn_source == "NV":
......
...@@ -2,10 +2,8 @@ from collections import namedtuple ...@@ -2,10 +2,8 @@ from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
...@@ -85,11 +83,13 @@ class InterruptedException(BaseException): ...@@ -85,11 +83,13 @@ class InterruptedException(BaseException):
pass pass
if opts.randn_source == "CPU": def replace_torchsde_browinan():
import torchsde._brownian.brownian_interval import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed): def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed)) return devices.randn_local(seed, size).to(device=device, dtype=dtype)
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn torchsde._brownian.brownian_interval._randn = torchsde_randn
replace_torchsde_browinan()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment