Commit 42e478be authored by novelailab's avatar novelailab

bfloat16 because it nans

parent 1a6751df
......@@ -124,7 +124,7 @@ class DalleMiniModel(nn.Module):
self.config = config
self.model = MinDalle(
models_root=config.model_path,
dtype=torch.float16,
dtype=torch.bfloat16,
device='cuda',
is_mega=True,
is_reusable=True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment