Commit 558c30dc authored by Arda Cihaner

Removing attention mask from encoder

parent e2a4d6b1
@@ -5,11 +5,8 @@ from basedformer.utils import *
 from basedformer.models import base_image
 import einops
-def _attn(query, key, value, causal_mask, masked_bias,
-          attention_mask=None, scale_attn=None):
+def _attn(query, key, value, attention_mask=None, scale_attn=None):
     attn_weights = torch.matmul(query, key.transpose(-1, -2))
-    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
     attn_weights = attn_weights / scale_attn
     if attention_mask is not None:
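With the causal mask and masked_bias gone, _attn reduces to plain bidirectional scaled dot-product attention over the patch tokens. A minimal standalone sketch of the resulting function follows; the softmax and the final matmul with value are not visible in this hunk, so that tail is assumed to match the usual HF-style implementation the class comment refers to, and _attn_sketch is a hypothetical name used only for illustration.

import torch
import torch.nn.functional as F

def _attn_sketch(query, key, value, attention_mask=None, scale_attn=None):
    # query/key/value: (batch, n_head, seq_len, head_dim)
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = attn_weights / scale_attn
    if attention_mask is not None:
        # additive mask (0 for visible positions, large negative elsewhere);
        # assumed, since the body of this branch lies outside the hunk
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1)  # assumed tail
    return torch.matmul(attn_weights, value)        # assumed tail

# toy shapes: batch 2, 12 heads, 196 patch tokens, head_dim 64
q = torch.randn(2, 12, 196, 64)
k = torch.randn(2, 12, 196, 64)
v = torch.randn(2, 12, 196, 64)
out = _attn_sketch(q, k, v, None, torch.sqrt(torch.tensor(64.0)))
print(out.shape)  # torch.Size([2, 12, 196, 64])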
@@ -26,9 +23,6 @@ class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
     def __init__(self, config):
         nn.Module.__init__(self)
-        max_positions = 2049
-        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
-            1, 1, max_positions, max_positions).bool()
         self.head_dim = config.hidden_dim // config.n_head
         self.rotary_dim = self.head_dim // 4
         self.hidden_dim = config.hidden_dim
@@ -37,8 +31,6 @@ class SelfAttention(nn.Module):
         dtype = config.dtype
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
-        self.register_buffer("bias", bias)
-        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
@@ -59,11 +51,8 @@ class SelfAttention(nn.Module):
         torch.cat([k, key], dim=-2) # cat key
         torch.cat([v, value], dim=-2) # cat value
-        query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
-        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
         x = _attn(
-            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
+            query, key, value, None, self.scale_attn
         )
         x = x.transpose(1, 2).contiguous().view(B, S, H)
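Since the call site now passes None for the mask, the layer computes unmasked attention, which is the same math as PyTorch's fused kernel with no mask and is_causal=False. The comparison below is a sketch that assumes PyTorch >= 2.0 for F.scaled_dot_product_attention and the standard softmax/value-matmul tail of _attn; it is not part of this repository.

import torch
import torch.nn.functional as F

# toy ViT-Base-like shapes: 12 heads, 196 patch tokens, head_dim 64
q, k, v = (torch.randn(2, 12, 196, 64) for _ in range(3))
scale_attn = torch.sqrt(torch.tensor(64.0))

# what the simplified call site computes: unmasked, bidirectional attention
weights = torch.matmul(q, k.transpose(-1, -2)) / scale_attn
manual = torch.matmul(F.softmax(weights, dim=-1), v)

# fused equivalent; defaults to 1/sqrt(head_dim) scaling and no causal mask
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))  # True up to numerics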
@@ -135,11 +124,11 @@ class VisionTransformer(base_image.BaseVisionModel):
             'patch_size': 16,
             'hidden_dim': 768,
             'n_classes' : 1000,
-            'activation': gelu_new,
+            'activation': F.gelu,
             'image_size': (224, 224),
             'eps': 1e-5,
-            'device': torch.device('cuda'),
-            'dtype': torch.float16,
+            'device': torch.device('cpu'),
+            'dtype': torch.float32,
         }
         super().__init__(self.default_config)
         self.embed = ViTEmbeds(self.config)
@@ -151,9 +140,7 @@ class VisionTransformer(base_image.BaseVisionModel):
     def forward(self, x):
         p_size = self.config.patch_size
         patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p_size, s2=p_size)
-        print(patches.shape)
         patches = self.embed(patches)
-        print(patches.shape)
         for encoder in self.encoder_layers:
             patches = encoder(patches)
         return self.mlp_head(patches)
...
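As a quick sanity check on the patchification in forward, the snippet below runs the same einops.rearrange pattern with the default_config values from this diff (224x224 images, patch_size 16, 3 channels). It only checks shapes and does not touch the model; note that 16 * 16 * 3 = 768 equals hidden_dim, so the flattened patches feed self.embed directly.

import torch
import einops

patch_size, image_size, channels = 16, 224, 3  # from default_config above
x = torch.randn(2, channels, image_size, image_size)  # (B, C, H, W)
patches = einops.rearrange(
    x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size
)
print(patches.shape)  # torch.Size([2, 196, 768]): (224 // 16) ** 2 = 196 tokens,
                      # each flattened to 16 * 16 * 3 = 768 = hidden_dim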