Commit bdaa36c8 authored by brkirch's avatar brkirch

When device is MPS, use CPU for GFPGAN instead

GFPGAN will not work if the device is MPS, so default to CPU instead.
parent 84e97a98
...@@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32") ...@@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device() device = get_optimal_device()
device_codeformer = cpu if has_mps else device device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
def randn(seed, shape): def randn(seed, shape):
......
...@@ -21,7 +21,7 @@ def gfpgann(): ...@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
if gfpgan_constructor is None: if gfpgan_constructor is None:
...@@ -36,8 +36,8 @@ def gfpgann(): ...@@ -36,8 +36,8 @@ def gfpgann():
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
model.gfpgan.to(shared.device) model.gfpgan.to(devices.device_gfpgan)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment