Commit a05caa11 authored by Arda Cihaner's avatar Arda Cihaner

gpu fix

parent f809c1e8
...@@ -138,8 +138,8 @@ class VisionTransformer(base_image.BaseVisionModel): ...@@ -138,8 +138,8 @@ class VisionTransformer(base_image.BaseVisionModel):
'activation': gelu_new, 'activation': gelu_new,
'image_size': (224, 224), 'image_size': (224, 224),
'eps': 1e-5, 'eps': 1e-5,
'device': torch.device('cpu'), 'device': torch.device('gpu'),
'dtype': torch.float32, 'dtype': torch.float16,
} }
super().__init__(self.default_config) super().__init__(self.default_config)
self.embed = ViTEmbeds(self.config) self.embed = ViTEmbeds(self.config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment