Commit 42e478be authored by novelailab's avatar novelailab

bfloat16 because it nans

parent 1a6751df
...@@ -124,7 +124,7 @@ class DalleMiniModel(nn.Module): ...@@ -124,7 +124,7 @@ class DalleMiniModel(nn.Module):
self.config = config self.config = config
self.model = MinDalle( self.model = MinDalle(
models_root=config.model_path, models_root=config.model_path,
dtype=torch.float16, dtype=torch.bfloat16,
device='cuda', device='cuda',
is_mega=True, is_mega=True,
is_reusable=True is_reusable=True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment