Commit a05caa11 authored by Arda Cihaner's avatar Arda Cihaner

gpu fix

parent f809c1e8
......@@ -138,8 +138,8 @@ class VisionTransformer(base_image.BaseVisionModel):
'activation': gelu_new,
'image_size': (224, 224),
'eps': 1e-5,
'device': torch.device('cpu'),
'dtype': torch.float32,
'device': torch.device('gpu'),
'dtype': torch.float16,
}
super().__init__(self.default_config)
self.embed = ViTEmbeds(self.config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment