Commit 1aff393b authored by novelailab's avatar novelailab

float32 then? probably attn needs to be fp32?

parent 42e478be
...@@ -124,7 +124,7 @@ class DalleMiniModel(nn.Module): ...@@ -124,7 +124,7 @@ class DalleMiniModel(nn.Module):
self.config = config self.config = config
self.model = MinDalle( self.model = MinDalle(
models_root=config.model_path, models_root=config.model_path,
dtype=torch.bfloat16, dtype=torch.float32,
device='cuda', device='cuda',
is_mega=True, is_mega=True,
is_reusable=True is_reusable=True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment