novelai-storage / Basedformer / Commits

Commit f809c1e8
Authored Jun 16, 2022 by Arda Cihaner
Parent c58dfef8

ViT and ResNet
Showing 3 changed files with 275 additions and 0 deletions:

    basedformer/models/base_image.py    +25  -0
    basedformer/models/resnet.py        +89  -0
    basedformer/models/vit.py           +161 -0
basedformer/models/base_image.py (new file, mode 100644)
import torch.nn as nn
from dotmap import DotMap

class BaseVisionModel(nn.Module):
    def __init__(self, user_config):
        super().__init__()
        self.user_config = user_config
        self.config = self.configure_model()

    def configure_model(self):
        full_config = {}
        if not hasattr(self, 'default_config'):
            raise ValueError("No default config found, add one for the model to function")

        # apply defaults
        for k, v in self.default_config.items():
            full_config[k] = v

        # apply user defined config if provided
        for k, v in self.user_config.items():
            full_config[k] = v

        full_config = DotMap(full_config)
        return full_config
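
A minimal usage sketch of the config merge above (illustrative only; ToyVisionModel and its config keys are invented, not part of the commit). Note that default_config must be set before super().__init__() runs, since configure_model() reads it there:

# Hypothetical subclass demonstrating BaseVisionModel's default/user config merge.
class ToyVisionModel(BaseVisionModel):
    def __init__(self, user_config):
        # must exist before super().__init__(), which calls configure_model()
        self.default_config = {'hidden_dim': 256, 'n_layer': 4}
        super().__init__(user_config)

model = ToyVisionModel({'n_layer': 8})
print(model.config.hidden_dim, model.config.n_layer)  # 256 8 -- the user value wins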
basedformer/models/resnet.py (new file, mode 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        downsample = in_channels != out_channels
        self.residual = nn.Sequential()
        if downsample:
            # strided 1x1 projection so the shortcut matches the main branch's shape
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels)
            )
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                               stride=2 if downsample else 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out)) + self.residual(x)
        return F.relu(out)
class ResBlockBottleNeck(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        downsample = in_channels != out_channels
        self.residual = nn.Sequential()
        if downsample:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels)
            )
        # 1x1 -> 3x3 -> 1x1 bottleneck, squeezing to out_channels // 4 internally
        self.conv1 = nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(out_channels // 4, out_channels // 4, kernel_size=3,
                               stride=2 if downsample else 1, padding=1)
        self.conv3 = nn.Conv2d(out_channels // 4, out_channels, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(out_channels // 4)
        self.bn2 = nn.BatchNorm2d(out_channels // 4)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # add the shortcut before the final ReLU, as in ResBlock above
        out = self.bn3(self.conv3(out)) + self.residual(x)
        return F.relu(out)
class ResNet(nn.Module):
    def __init__(self, in_channels, out_size=1000, network_layers=18) -> None:
        super().__init__()
        base_chan = 64
        # depth -> (use bottleneck blocks, blocks per stage)
        network_config_dict = {
            18:  (False, (2, 2, 2, 2)),
            34:  (False, (3, 4, 6, 3)),
            50:  (True,  (3, 4, 6, 3)),
            101: (True,  (3, 4, 23, 3)),
            152: (True,  (3, 4, 36, 3))
        }
        # stem: 7x7/2 convolution followed by 3x3/2 max pooling
        self.layerin = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.resblocks = nn.ModuleList()
        network_config = network_config_dict[network_layers]
        is_bottleneck = network_config[0]
        curr_chan = base_chan
        prev_chan = curr_chan
        # channel count doubles after each stage; the first block of a stage downsamples
        for i in network_config[1]:
            for _ in range(i):
                resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck \
                    else ResBlock(prev_chan, curr_chan)
                self.resblocks.append(resblock)
                prev_chan = curr_chan
            curr_chan *= 2
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(prev_chan, out_size)

    def forward(self, x):
        out = self.layerin(x)
        for layer in self.resblocks:
            out = layer(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)
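
A quick shape check for the model above (illustrative, not part of the commit), assuming the standard 224x224 ImageNet-style input that the depth table targets:

# Build a ResNet-18 and run a forward pass on dummy data.
model = ResNet(in_channels=3, out_size=1000, network_layers=18)
x = torch.randn(2, 3, 224, 224)  # batch of two RGB images
logits = model(x)
print(logits.shape)              # torch.Size([2, 1000])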
basedformer/models/vit.py (new file, mode 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *  # the wildcard import is what supplies gelu_new used below
from basedformer.models import base_image
import einops
def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8,
                                     requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        device = config.device
        dtype = config.dtype
        self.register_buffer("scale_attn",
                             torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  # -1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape  # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2)      # cat key
            value = torch.cat([v, value], dim=-2)  # cat value

        query_length, key_length = query.size(-2), key.size(-2)  # seq_len, seq_len
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)
        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4,
                             device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim,
                             device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x):
        x = self.ff1(x)
        x = self.activation(x)
        x = self.ff2(x)
        return x
class ViTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps,
                                       device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps,
                                        device=config.device, dtype=config.dtype)
        self.ff = FeedForward(config)
        self.attn = SelfAttention(config)

    def forward(self, x):
        # pre-norm block: attention then feedforward, each wrapped in a residual
        residual = x
        x = self.ln_preattn(x)
        x = self.attn(x)[0]
        x = residual + x

        residual = x
        x = self.ln_postattn(x)
        x = self.ff(x)
        return x + residual
class ViTEmbeds(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        p_size = config.patch_size
        channels = config.channels
        dim = config.hidden_dim
        num_patches = (config.image_size[1] // p_size) * (config.image_size[0] // p_size)
        # linear patch projection plus a learned [CLS] token and position embeddings
        self.lin_emb = nn.Linear((p_size ** 2) * channels, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, x: torch.Tensor):
        embed = self.lin_emb(x)
        batch_size = x.size()[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embed = torch.cat((cls_tokens, embed), dim=1)
        return embed + self.pos
class VisionTransformer(base_image.BaseVisionModel):
    def __init__(self):
        self.default_config = {
            'n_layer': 12,
            'n_head': 8,
            'channels': 3,
            'patch_size': 16,
            'hidden_dim': 768,
            'n_classes': 1000,
            'activation': gelu_new,
            'image_size': (224, 224),
            'eps': 1e-5,
            'device': torch.device('cpu'),
            'dtype': torch.float32,
        }
        # no user overrides yet: the defaults are passed through BaseVisionModel's merge as-is
        super().__init__(self.default_config)
        self.embed = ViTEmbeds(self.config)
        self.encoder_layers = nn.ModuleList()
        for _ in range(self.config.n_layer):
            self.encoder_layers.append(ViTEncoder(self.config))
        self.mlp_head = nn.Linear(self.config.hidden_dim, self.config.n_classes)

    def forward(self, x):
        p_size = self.config.patch_size
        # flatten each p_size x p_size patch: [b, c, h, w] -> [b, n_patches, p*p*c]
        patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
                                   s1=p_size, s2=p_size)
        patches = self.embed(patches)
        for encoder in self.encoder_layers:
            patches = encoder(patches)
        return self.mlp_head(patches)
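
A quick shape check for the ViT (illustrative, not part of the commit), using the built-in defaults; with 16x16 patches on a 224x224 image there are 196 patches plus the [CLS] token:

# Instantiate with the default config and run a forward pass on dummy data.
model = VisionTransformer()
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([2, 197, 1000]) -- the head is applied to every token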