Commit bdaa36c8 authored by brkirch's avatar brkirch

When device is MPS, use CPU for GFPGAN instead

GFPGAN will not work if the device is MPS, so default to CPU instead.
parent 84e97a98
......@@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
def randn(seed, shape):
......
......@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device)
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
if gfpgan_constructor is None:
......@@ -36,8 +36,8 @@ def gfpgann():
else:
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
model.gfpgan.to(devices.device_gfpgan)
loaded_gfpgan_model = model
return model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment