Commit f809c1e8 authored by Arda Cihaner

ViT and ResNet

parent c58dfef8
import torch.nn as nn
from dotmap import DotMap


class BaseVisionModel(nn.Module):
    def __init__(self, user_config):
        super().__init__()
        self.user_config = user_config
        self.config = self.configure_model()

    def configure_model(self):
        full_config = {}
        if not hasattr(self, 'default_config'):
            raise ValueError("No default config found; add one for the model to function")
        # apply defaults
        for k, v in self.default_config.items():
            full_config[k] = v
        # apply user-defined config if provided (user keys override defaults)
        for k, v in self.user_config.items():
            full_config[k] = v
        return DotMap(full_config)
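# A minimal usage sketch (not part of the original commit), using a hypothetical
# ToyModel subclass, to illustrate the merge order in configure_model:
# defaults are applied first, then user-supplied keys win.
if __name__ == "__main__":
    class ToyModel(BaseVisionModel):
        def __init__(self, user_config):
            # default_config must exist before BaseVisionModel.__init__ runs
            self.default_config = {'hidden_dim': 768, 'n_layer': 12}
            super().__init__(user_config)

    model = ToyModel({'n_layer': 6})
    assert model.config.hidden_dim == 768  # taken from default_config
    assert model.config.n_layer == 6       # overridden by user_config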
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # downsample (stride-2) whenever the channel count changes
        downsample = in_channels != out_channels
        # identity shortcut by default; 1x1 strided conv projection when downsampling
        self.residual = nn.Sequential()
        if downsample:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels)
            )
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2 if downsample else 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # add the (possibly projected) shortcut before the final activation
        out = self.bn2(self.conv2(out)) + self.residual(x)
        return F.relu(out)
class ResBlockBottleNeck(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # downsample (stride-2) whenever the channel count changes
        downsample = in_channels != out_channels
        # identity shortcut by default; 1x1 strided conv projection when downsampling
        self.residual = nn.Sequential()
        if downsample:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels)
            )
        # 1x1 reduce -> 3x3 -> 1x1 expand bottleneck
        self.conv1 = nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(out_channels // 4, out_channels // 4, kernel_size=3, stride=2 if downsample else 1, padding=1)
        self.conv3 = nn.Conv2d(out_channels // 4, out_channels, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(out_channels // 4)
        self.bn2 = nn.BatchNorm2d(out_channels // 4)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # add the shortcut before the final activation (no ReLU on the last BN output)
        out = self.bn3(self.conv3(out)) + self.residual(x)
        return F.relu(out)
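# A tiny shape sketch (not part of the original commit): when the input and
# output channel counts differ, both block types halve the spatial size via
# their stride-2 path, and the 1x1-conv shortcut projects the input to match.
# The 56x56 feature-map size is just an illustrative assumption.
if __name__ == "__main__":
    x = torch.randn(1, 64, 56, 56)
    print(ResBlock(64, 64)(x).shape)             # torch.Size([1, 64, 56, 56])
    print(ResBlock(64, 128)(x).shape)            # torch.Size([1, 128, 28, 28])
    print(ResBlockBottleNeck(64, 128)(x).shape)  # torch.Size([1, 128, 28, 28])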
class ResNet(nn.Module):
    def __init__(self, in_channels, out_size=1000, network_layers=18) -> None:
        super().__init__()
        base_chan = 64
        # depth -> (use bottleneck blocks, blocks per stage)
        network_config_dict = {
            18: (False, (2, 2, 2, 2)),
            34: (False, (3, 4, 6, 3)),
            50: (True, (3, 4, 6, 3)),
            101: (True, (3, 4, 23, 3)),
            152: (True, (3, 4, 36, 3))
        }
        # stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool
        self.layerin = nn.Sequential(
            nn.Conv2d(in_channels, base_chan, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(base_chan),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.resblocks = nn.ModuleList()
        network_config = network_config_dict[network_layers]
        is_bottleneck = network_config[0]
        # stage widths double each stage: 64, 128, 256, 512
        curr_chan = base_chan
        prev_chan = curr_chan
        for i in network_config[1]:
            for _ in range(i):
                resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan)
                self.resblocks.append(resblock)
                prev_chan = curr_chan
            curr_chan *= 2
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(prev_chan, out_size)

    def forward(self, x):
        out = self.layerin(x)
        for layer in self.resblocks:
            out = layer(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)
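# A quick shape-check sketch (not part of the original commit): builds a
# basic-block ResNet-18 and a bottleneck ResNet-50 as defined above and runs a
# random batch through them. The 3-channel 224x224 input is an assumption
# matching ImageNet-style usage.
if __name__ == "__main__":
    dummy = torch.randn(2, 3, 224, 224)
    for depth in (18, 50):
        net = ResNet(3, out_size=1000, network_layers=depth)
        print(depth, net(dummy).shape)  # expected: torch.Size([2, 1000])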
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from basedformer.models import base_image
import einops


def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        super().__init__()
        max_positions = 2049
        # NOTE: this lower-triangular bias is a causal mask carried over from the
        # copied GPT-style code; standard ViT attention is bidirectional.
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        device = config.device
        dtype = config.dtype
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  # -1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape  # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        if kv is not None:
            k, v = kv
            # cat cached key/value with the new ones (everything other than the last
            # added token is cached), so query can attend to the whole sequence.
            key = torch.cat([k, key], dim=-2)
            value = torch.cat([v, value], dim=-2)
        query_length, key_length = query.size(-2), key.size(-2)  # seq_len, seq_len
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )
        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None
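# A minimal cache-usage sketch (not part of the original commit), assuming a
# DotMap config with only the fields SelfAttention reads. It runs a prefix with
# cache=True, then feeds one more token together with the returned (key, value)
# pair, mirroring incremental decoding.
if __name__ == "__main__":
    from dotmap import DotMap

    cfg = DotMap({'hidden_dim': 64, 'n_head': 4,
                  'device': torch.device('cpu'), 'dtype': torch.float32})
    attn = SelfAttention(cfg)
    prefix = torch.randn(1, 8, 64)
    out, kv = attn(prefix, cache=True)           # kv holds keys/values for 8 tokens
    nxt = torch.randn(1, 1, 64)
    out2, kv2 = attn(nxt, kv=kv, cache=True)     # attends over all 9 positions
    print(out.shape, out2.shape, kv2[0].shape)   # (1, 8, 64) (1, 1, 64) (1, 4, 9, 16)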
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4, device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim, device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x):
        x = self.ff1(x)
        x = self.activation(x)
        x = self.ff2(x)
        return x
class ViTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = FeedForward(config)
        self.attn = SelfAttention(config)

    def forward(self, x):
        # pre-norm attention block with residual connection
        residual = x
        x = self.ln_preattn(x)
        x = self.attn(x)[0]
        x = residual + x
        # pre-norm feed-forward block with residual connection
        residual = x
        x = self.ln_postattn(x)
        x = self.ff(x)
        return x + residual
class ViTEmbeds(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        p_size = config.patch_size
        channels = config.channels
        dim = config.hidden_dim
        num_patches = (config.image_size[1] // p_size) * (config.image_size[0] // p_size)
        # linear projection of flattened patches, plus a learnable [CLS] token
        # and learned position embeddings
        self.lin_emb = nn.Linear((p_size ** 2) * channels, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, x: torch.Tensor):
        embed = self.lin_emb(x)
        batch_size = x.size()[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # prepend the [CLS] token, then add position embeddings
        embed = torch.cat((cls_tokens, embed), dim=1)
        return embed + self.pos
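# A small shape sketch (not part of the original commit): with the ViT defaults
# assumed below (224x224 images, 16x16 patches, hidden_dim 768), the embedding
# produces (224 / 16) ** 2 = 196 patch tokens plus one [CLS] token, i.e. 197 tokens.
if __name__ == "__main__":
    from dotmap import DotMap

    cfg = DotMap({'patch_size': 16, 'channels': 3,
                  'hidden_dim': 768, 'image_size': (224, 224)})
    embeds = ViTEmbeds(cfg)
    flat_patches = torch.randn(2, 196, 16 * 16 * 3)
    print(embeds(flat_patches).shape)  # expected: torch.Size([2, 197, 768])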
class VisionTransformer(base_image.BaseVisionModel):
    def __init__(self, user_config=None):
        self.default_config = {
            'n_layer': 12,
            'n_head': 8,
            'channels': 3,
            'patch_size': 16,
            'hidden_dim': 768,
            'n_classes': 1000,
            'activation': gelu_new,
            'image_size': (224, 224),
            'eps': 1e-5,
            'device': torch.device('cpu'),
            'dtype': torch.float32,
        }
        # default_config must be set before BaseVisionModel merges it with user overrides
        super().__init__(user_config if user_config is not None else {})
        self.embed = ViTEmbeds(self.config)
        self.encoder_layers = nn.ModuleList()
        for _ in range(self.config.n_layer):
            self.encoder_layers.append(ViTEncoder(self.config))
        self.mlp_head = nn.Linear(self.config.hidden_dim, self.config.n_classes)

    def forward(self, x):
        p_size = self.config.patch_size
        # split the image into flattened p_size x p_size patches: [B, num_patches, p*p*c]
        patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p_size, s2=p_size)
        patches = self.embed(patches)
        for encoder in self.encoder_layers:
            patches = encoder(patches)
        # classify from the [CLS] token
        return self.mlp_head(patches[:, 0])
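# A minimal end-to-end sketch (not part of the original commit): builds the
# VisionTransformer with its default config and classifies a random batch of
# 224x224 RGB images; the (2, 1000) shape assumes the classification head is
# applied to the [CLS] token as above.
if __name__ == "__main__":
    vit = VisionTransformer()
    images = torch.randn(2, 3, 224, 224)
    print(vit(images).shape)  # expected: torch.Size([2, 1000])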