novelai-storage / Basedformer · Commits · 482a7bae

Commit 482a7bae authored May 24, 2022 by novelailab
Parent: b39e6d1b

    update

Showing 9 changed files with 855 additions and 17 deletions (+855 / -17)
basedformer/lm_utils.py            +1    -1
basedformer/models/__init__.py     +4    -1
basedformer/models/alibi.py        +203  -0
basedformer/models/fast.py         +296  -0
basedformer/models/gptj.py         +15   -4
basedformer/models/perceiver.py    +265  -0
scripts/cudagraph.py               +1    -1
scripts/q_only.py                  +54   -0
train.py                           +16   -10
basedformer/lm_utils.py  (view file @ 482a7bae)

@@ -28,7 +28,7 @@ def init_weights(model, n_layer):
 def init(model_class, config):
     model = model_class(config)
-    model.init_weights()
+    init_weights(model, config["n_layer"])
     return model

 def no_init(model_class, config):
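For context, a usage sketch of the changed helper (not additional commit content; it mirrors the call site added to train.py later in this commit, with the model_config values train.py now uses):

# Sketch: lm_utils.init now constructs the model and calls the standalone
# init_weights(model, n_layer) instead of a model method.
import torch
from basedformer import models, lm_utils

model_config = {
    "n_layer": 12,
    "n_head": 12,
    "hidden_dim": 768,
    "vocab_dim": 50400,
    "eps": 1e-5,
    "q_only": True,
    "activation": torch.nn.GELU(),
}

model = lm_utils.init(models.fast.GPTJModel, model_config).cuda().float()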
basedformer/models/__init__.py  (view file @ 482a7bae)

@@ -2,12 +2,15 @@ from . import gptj
 from . import gpt2
 from . import fairseq
 from . import gptneo
+from . import alibi
+from . import fast

 MODEL_MAP = {
     "gptj": gptj.GPTJModel,
     "gpt2": gpt2.GPT2Model,
     "gpt-fairseq": fairseq.GPTFairModel,
-    "gpt-neo": gptneo.GPTNeoModel
+    "gpt-neo": gptneo.GPTNeoModel,
+    "alibi": alibi.AlibiModel,
 }

 def get_model(model_name: str):
basedformer/models/alibi.py  (new file, 0 → 100644, view @ 482a7bae)

from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm

def get_slopes(n):
    def get_slopes_power_of_2(n):
        start = (2 ** (-2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)                     #In the paper, we only train models that have 2^a heads for some a. This function has
    else:                                                   #some good properties that only occur when the input is a power of 2. To maintain that even
        closest_power_of_2 = 2 ** math.floor(math.log2(n))  #when the number of heads is not a power of 2, we use this workaround.
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):

    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.register_buffer("slopes", torch.Tensor(get_slopes(config.n_head)))
        #In the next line, the part after the * is what constructs the diagonal matrix (right matrix in Figure 3 in the paper).
        #If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, but one where all rows are identical.
        #This works because the softmax operation is invariant to translation, and our bias functions are always linear.
        print(self.slopes.shape)
        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_positions).unsqueeze(0).unsqueeze(0).expand(config.n_head, -1, -1)
        self.alibi = self.alibi.view(config.n_head, 1, max_positions)
        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        self.q_only = config.q_only
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        if config.q_only:
            self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        else:
            self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        if self.q_only:
            key = self.k_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
            value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
        else:
            key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
            value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            offset = kv[0].shape[-2]
        else:
            offset = 0

        self.alibi = self.alibi.repeat(B, 1, 1) # batch_size, 1, 1

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2) # cat key
            value = torch.cat([v, value], dim=-2) # cat value

        query_length, key_length = query.size(-2), key.size(-2)
        #causal mask with generation in mind
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
        print(causal_mask.shape)
        print(self.alibi.shape)
        x = _attn(
            query, key, value, causal_mask + self.alibi, self.masked_bias, None, self.scale_attn
        )

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None

class FeedForward(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x

class AlibiLayer(nn.Module):
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        #self.ln_preattn = nn.LogSoftmax(dim=-2)
        self.ff = ff(config)
        self.attn = attn(config)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out, kv = ck(self.attn, x, kv, cache)
            #attn_out, kv = self.attn(x, kv=kv, cache=cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out

        return x, kv

class AlibiModel(base_lm.BaseModel):
    def __init__(self, user_config, **kwargs):
        self.default_config = {
            'n_layer': 6,
            'n_head': 8,
            'n_tokens': 2048,
            'hidden_dim': 512,
            'vocab_dim': 50400,
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
            'Layer': AlibiLayer,
            'activation': gelu_new,
            'SelfAttention': SelfAttention,
            'FeedForward': FeedForward,
        }
        base_lm.BaseModel.__init__(self, user_config, **kwargs)
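For a sense of what the ALiBi construction above produces, here is a small sketch (not part of the commit) that redoes the slope/bias math from SelfAttention.__init__ with 8 heads and a short position range; the slope values are the same ones get_slopes(8) returns.

# Sketch, not commit content: the ALiBi bias is slope_h * position, one slope
# per head; SelfAttention later adds it to the causal mask inside _attn.
import torch

n_head, max_positions = 8, 6
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(n_head)])   # 1/2, 1/4, ..., 1/256, as get_slopes(8) yields
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_positions).unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
alibi = alibi.view(n_head, 1, max_positions)
print(alibi.shape)   # torch.Size([8, 1, 6])
print(alibi[0])      # tensor([[0.0000, 0.5000, 1.0000, 1.5000, 2.0000, 2.5000]])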
basedformer/models/fast.py  (new file, 0 → 100644, view @ 482a7bae)

from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm

def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset, :], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)

def token_shift_fast(x, n_tokens=1):
    size = x.size()[-1] // (n_tokens + 1)
    seq_len = x.size()[-2]
    padded_x = nn.functional.pad(x[:, :, -size:], (0, 0, n_tokens, 0))
    token_shifts = [padded_x[:, offset:(offset + seq_len)] for offset in range(n_tokens)]
    current_x = [x[:, :, len(token_shifts) * size:]]
    x = torch.cat(token_shifts + current_x, dim=-1)
    return x

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):

    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config, small_attn=False):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        self.q_only = config.q_only
        self.small_attn = small_attn
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        if config.q_only:
            self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        else:
            self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        if small_attn:
            self.q_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        else:
            self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        if small_attn:
            self.out_proj = nn.Linear(self.head_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        else:
            self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)

    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
        if self.small_attn:
            query = self.q_proj(x).view(B, S, 1, self.head_dim)
        else:
            query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)

        if self.q_only:
            key = self.k_proj(x).view(B, S, 1, self.head_dim)
            value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
        else:
            key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
            value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            offset = kv[0].shape[-2]
        else:
            offset = 0

        if self.rotary_dim < self.head_dim:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2) # cat key
            value = torch.cat([v, value], dim=-2) # cat value

        query_length, key_length = query.size(-2), key.size(-2)
        #causal mask with generation in mind
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None

class FeedForward(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        if self.config.token_shift:
            x = token_shift_fast(x)
        x = self.ff2(x)
        return x

class GPTJLayer(nn.Module):
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.config = config
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        #self.ln_preattn = nn.LogSoftmax(dim=-2)
        self.ff = ff(config)
        self.attn = attn(config)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out, kv = ck(self.attn, x, kv, cache)
            #attn_out, kv = self.attn(x, kv=kv, cache=cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out

        return x, kv

class GPTJnoattnLayer(nn.Module):
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.config = config
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = ff(config)
        self.attn = attn(config, small_attn=True)

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out, kv = ck(self.attn, x, kv, cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        return x, kv

class GPTJModel(base_lm.BaseModel):
    def __init__(self, user_config, **kwargs):
        nn.Module.__init__(self)
        self.default_config = {
            'n_layer': 6,
            'n_head': 8,
            'n_tokens': 2048,
            'hidden_dim': 512,
            'vocab_dim': 50400,
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
            'Layer': GPTJLayer,
            'AlternateLayer': GPTJnoattnLayer,
            'activation': gelu_new,
            'SelfAttention': SelfAttention,
            'FeedForward': FeedForward,
        }

        #configuration
        self.user_config = user_config
        self.config = self.configure_model()
        config = self.config

        #modeling
        self.n_layer = config.n_layer
        self.hidden_dim = config.hidden_dim
        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
        for i in range(config.n_layer // 2):
            config.layer_idx = i
            self.layers.append(
                config.Layer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
                )
            )
            self.layers.append(
                config.AlternateLayer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
                )
            )
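A small worked example of the token shift used in this file's FeedForward (a sketch, not part of the commit; the function body is copied from fast.py above so the snippet runs on its own):

# Sketch, not commit content: with n_tokens=1 the channel dimension is split in
# half; the output's first half is the previous position's second half (zeros at
# position 0), the output's second half is the current position's second half,
# and the input's first half is not used.
import torch
import torch.nn as nn

def token_shift_fast(x, n_tokens=1):   # copied from fast.py above
    size = x.size()[-1] // (n_tokens + 1)
    seq_len = x.size()[-2]
    padded_x = nn.functional.pad(x[:, :, -size:], (0, 0, n_tokens, 0))
    token_shifts = [padded_x[:, offset:(offset + seq_len)] for offset in range(n_tokens)]
    current_x = [x[:, :, len(token_shifts) * size:]]
    x = torch.cat(token_shifts + current_x, dim=-1)
    return x

x = torch.arange(12, dtype=torch.float32).view(1, 3, 4)   # batch=1, seq=3, dim=4
print(token_shift_fast(x)[0])
# tensor([[ 0.,  0.,  2.,  3.],
#         [ 2.,  3.,  6.,  7.],
#         [ 6.,  7., 10., 11.]])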
basedformer/models/gptj.py  (view file @ 482a7bae)

@@ -59,12 +59,18 @@ class SelfAttention(nn.Module):
         self.rotary_dim = self.head_dim // 4
         self.hidden_dim = config.hidden_dim
         self.n_head = config.n_head
+        self.q_only = config.q_only
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
-        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
-        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        if config.q_only:
+            self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+            self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        else:
+            self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+            self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)

@@ -76,8 +82,12 @@ class SelfAttention(nn.Module):
         # split heads into: [batch, head, sequence, head_dim]
         # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
         query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
-        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
-        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        if self.q_only:
+            key = self.k_proj(x).view(B, S, 1, self.head_dim)
+            value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
+        else:
+            key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
+            value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

         if kv:
             offset = kv[0].shape[-2]

@@ -147,6 +157,7 @@ class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
         self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        #self.ln_preattn = nn.LogSoftmax(dim=-2)
         self.ff = ff(config)
         self.attn = attn(config)
         self.tick = True
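A rough parameter count (a sketch, not part of the commit) shows what the q_only branch saves: k_proj and v_proj shrink from hidden_dim→hidden_dim to hidden_dim→head_dim, i.e. a single key/value head shared by all query heads, similar in spirit to multi-query attention.

# Sketch, not commit content: per-layer k/v projection weights with and without
# q_only, for the 28-layer / 16-head / 4096-dim config used in scripts/q_only.py.
hidden_dim, n_head = 4096, 16
head_dim = hidden_dim // n_head                  # 256
full_kv = 2 * hidden_dim * hidden_dim            # k_proj + v_proj, q_only=False
small_kv = 2 * hidden_dim * head_dim             # k_proj + v_proj, q_only=True
print(full_kv, small_kv, full_kv // small_kv)    # 33554432 2097152 16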
basedformer/models/perceiver.py  (new file, 0 → 100644, view @ 482a7bae)

from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm

def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset, :], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):

    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class CrossAttentionMod(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x, y, kv=None, cache=False):
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #latent query
        key = self.k_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context key
        value = self.v_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context value

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            torch.cat([k, key], dim=-2) # cat key
            torch.cat([v, value], dim=-2) # cat value

        query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        self.q_only = config.q_only
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        if config.q_only:
            self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        else:
            self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
            self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)

        sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)

    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
        if self.q_only:
            key = self.k_proj(x).view(B, S, 1, self.head_dim)
            value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
        else:
            key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
            value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            offset = kv[0].shape[-2]
        else:
            offset = 0

        if self.rotary_dim < self.head_dim:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2) # cat key
            value = torch.cat([v, value], dim=-2) # cat value

        query_length, key_length = query.size(-2), key.size(-2)
        #causal mask with generation in mind
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None

class FeedForward(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x

class PerceiverARLayer(nn.Module):
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = ff(config)
        self.attn = attn(config)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out, kv = ck(self.attn, x, kv, cache)
            #attn_out, kv = self.attn(x, kv=kv, cache=cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out

        return x, kv

class PerceiverARModel(base_lm.BaseModel):
    def __init__(self, user_config, **kwargs):
        self.default_config = {
            'n_layer': 6,
            'n_head': 8,
            'n_tokens': 2048,
            'hidden_dim': 512,
            'vocab_dim': 50400,
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
            'Layer': PerceiverARLayer,
            'activation': gelu_new,
            'SelfAttention': SelfAttention,
            'FeedForward': FeedForward,
        }
        base_lm.BaseModel.__init__(self, user_config, **kwargs)
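A quick shape check for the new cross-attention module (a sketch, not part of the commit; it assumes the CrossAttentionMod class defined above and runs on CPU in float32):

# Sketch, not commit content: latents x attend to a context y of the same
# length; with cache=False the module returns just the attended output.
import torch

attn = CrossAttentionMod(hidden_dim=64, n_head=4, device="cpu", dtype=torch.float32)
x = torch.randn(2, 16, 64)   # latent queries
y = torch.randn(2, 16, 64)   # context to attend over
out = attn(x, y)
print(out.shape)             # torch.Size([2, 16, 64])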
scripts/cudagraph.py  (view file @ 482a7bae)

@@ -5,7 +5,7 @@ from time import perf_counter, perf_counter_ns
 import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
-from basedformer.hypernet import *
+from basedformer.models.hypernet import *
 import sys
 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
scripts/q_only.py  (new file, 0 → 100644, view @ 482a7bae)

from basedformer import models, utils
import torch

config = {
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
}

config = {
    "n_layer": 40,
    "n_head": 40,
    "hidden_dim": 5120,
}

config_q = {
    **config,
    "q_only": True
}

#init param matched GPT
gpt = models.fairseq.GPTFairModel(config).cuda().half()
utils.print_parameters(gpt)

bsz = 3
cached_seq = 1000
y = torch.randint(0, 50256, (bsz, cached_seq)).long().cuda()
x = torch.randint(0, 50256, (bsz, 1)).long().cuda()
cache_f = torch.rand(bsz, config["n_head"], cached_seq, config["hidden_dim"] // config["n_head"]).cuda().half()
cache_f = (cache_f, cache_f)
cache_f = [cache_f for _ in range(config["n_layer"])]
print(len(cache_f))
print(cache_f[0][1].shape)
######
cache_q = torch.rand(bsz, 1, cached_seq, config["hidden_dim"] // config["n_head"]).cuda().half()
cache_q = (cache_q, cache_q)
cache_q = [cache_q for _ in range(config["n_layer"])]
print(cache_q[0][0].shape)

with torch.no_grad():
    #print("Initial Context GPT:")
    #utils.timeit(func=lambda: gpt(y), r=10, n=10)
    out = gpt(y, cache=True)
    print(out[1][0][0].shape)
    print("GPT")
    utils.timeit(func=lambda: gpt(x, kv=cache_f), r=10, n=10)

'''
del gpt
#init param matched Q-Only
gpt_q = models.gptj.GPTJModel(config_q).cuda().half()
utils.print_parameters(gpt_q)
with torch.no_grad():
    #print("Initial Context GPT-Q:")
    #utils.timeit(func=lambda: gpt_q(y), r=10, n=10)
    out_q = gpt_q(y, cache=True)
    print("GPT-Q:")
    utils.timeit(func=lambda: gpt_q(x, kv=cache_q), r=10, n=10)
'''
train.py  (view file @ 482a7bae)

@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
 from torch.utils import data
-from basedformer import optimizer, utils, gptj, noemblm, gpt2
+from basedformer import optimizer, utils, models, lm_utils
 import yaml
 import sys
 from tqdm import tqdm

@@ -14,12 +14,17 @@ import wandb
 import numpy as np
 import os

+def softmax_activation(x):
+    return F.log_softmax(x, dim=-1)
+
 model_config = {
-    "n_layer": 3,
-    "n_head": 16,
-    "hidden_dim": 1024,
+    "n_layer": 12,
+    "n_head": 12,
+    "hidden_dim": 768,
     "vocab_dim": 50400,
     "eps": 1e-5,
+    "q_only": True,
+    "activation": torch.nn.GELU(),
 }

 # we need 250 batch size to train the small GPT.

@@ -29,9 +34,9 @@ train_config = {
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshift-superhighlr-residualgate-3L-16A-1024H",
     "do_save": False,
-    "run_name": "gptj-nopos-tokenshift-superhighlr-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
-    "lr": 1e-3,
-    "end_lr": 1e-3,
+    "run_name": "gptj-owt2-512ctx-12L-12H-768H-16bs-1e-4lr-q-only-smallattneveryotherlayer",
+    "lr": 1e-4,
+    "end_lr": 1e-4,
     "warmup_steps": 100,
     "bs": 16,
     "gas": 1,

@@ -47,7 +52,8 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = noemblm.GPTNEBaseLM.init(model_config).cuda().float()
+model = lm_utils.init(models.fast.GPTJModel, model_config).cuda().float()
+utils.print_parameters(model)
 model.train()

 cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))

@@ -87,7 +93,7 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
+            logits = model(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             #roll down the sequence
             logits = logits.view(-1, logits.shape[-1])

@@ -117,7 +123,7 @@ for input_ids, labels in t:
     opt.zero_grad()
     sec_per_step = (time.perf_counter() - timex)
     step_per_sec = (1. / sec_per_step)
-    tokens_per_sec = (step_per_sec * 1024) * bs * gas
+    tokens_per_sec = (step_per_sec * 512) * bs * gas
     t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
     wandb.log(
         {