Commit 043d2edc authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf

Better naming

parent f383af27
......@@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
@contextlib.contextmanager
def manual_autocast():
def manual_cast():
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
......@@ -148,10 +148,10 @@ def autocast(disable=False):
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast()
return manual_cast()
if has_mps() and shared.cmd_opts.precision != "full":
return manual_autocast()
return manual_cast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment