Commit f261a4a5 authored by AUTOMATIC's avatar AUTOMATIC

use selected device instead of always cuda for UniPC sampler

parent a11ce2b9
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import torch import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
from modules import shared from modules import shared, devices
class UniPCSampler(object): class UniPCSampler(object):
def __init__(self, model, **kwargs): def __init__(self, model, **kwargs):
...@@ -16,8 +17,8 @@ class UniPCSampler(object): ...@@ -16,8 +17,8 @@ class UniPCSampler(object):
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != devices.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(devices.device)
setattr(self, name, attr) setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update): def set_hooks(self, before_sample, after_sample, after_update):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment