novelai-storage / Basedformer / Commits

Commit 074d25c5, authored Apr 29, 2022 by novelailab
sampling working, gpt-neo(x) init
parent 9223ef70
Showing 5 changed files with 651 additions and 33 deletions.
basedformer/gptneo.py (+189 -0)
basedformer/gptneox.py (+243 -0)
basedformer/utils.py (+5 -3)
eval_tasks/__init__.py (+0 -0)
scripts/test_cache.py (+214 -30)
basedformer/gptneo.py 0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat

try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping

import os
from pathlib import Path
import math
from basedformer import lm_base
def shift_tokens(x, amt, eps=1e-5):
    n, device = x.shape[1], x.device
    cumsum = x.cumsum(dim=1)

    *x, x_pass = x.chunk(amt + 1, dim=-1)
    *x_cumsum, _ = cumsum.chunk(amt + 1, dim=-1)

    amts = 2 ** torch.arange(amt)
    amts = amts.tolist()
    shifts = []
    denom = torch.arange(n, device=device)

    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
        shifted_chunk = shift(x_cumsum_chunk, amt, dim=-2) - shift(x_cumsum_chunk, 2 * amt, dim=-2)
        shifted_denom = shift(denom, amt, dim=-1) - shift(denom, 2 * amt, dim=-1)
        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
        shifts.append(normed_shifted_x)

    return torch.cat((*shifts, x_pass), dim=-1)

def shift(x, amt, dim=-1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)
def _attn(query, key, value, causal_mask, masked_bias, attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  # -1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape  # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2)      # cat key
            value = torch.cat([v, value], dim=-2)  # cat value

        query_length, key_length = query.size(-2), key.size(-2)  # seq_len, seq_len
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x
class GPTNeoLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim * 4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        residual = residual + attn_out
        x = self.ln_postattn(x)
        ff_out = self.ff(x, act_ck)
        x = residual + ff_out
        return x
class GPTNeoModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoLayer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x
class GPTNeoBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPTNeoModel
def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPTNeoBaseLM.load(config, path, state_dict)
    return model
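For orientation, here is a minimal usage sketch of the GPTNeoModel defined above. This is an illustration, not part of the commit; the tiny CPU-sized config and the random token input are assumptions made purely so the example is cheap to run.

import torch
from basedformer.gptneo import GPTNeoModel

# small config so the sketch runs on CPU; real checkpoints are far larger
model = GPTNeoModel(hidden_dim=64, n_layer=2, n_head=4, vocab_dim=256, eps=1e-5,
                    device="cpu", dtype=torch.float32)
tokens = torch.randint(0, 256, (1, 8))   # [batch, seq] of token ids
logits = model(tokens)                   # [batch, seq, vocab_dim], cast to float32
print(logits.shape)

For real checkpoints the intended entry point is GPTNeoBaseLM.load(...) via load_gpt_j; the direct constructor above only builds randomly initialized weights.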
basedformer/gptneox.py 0 → 100644
from typing import KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat

try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping

import os
from pathlib import Path
import math
from basedformer import lm_base
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1] + offset, :], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)
def _attn(query, key, value, causal_mask, masked_bias, attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  # -1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)
    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape  # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv:
            offset = kv[0].shape[-2]
        else:
            offset = 0

        if self.rotary_dim < self.head_dim:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)

        if kv:
            k, v = kv
            # cat key and value (get the whole sequence, other than the last added token all are cached),
            # so query can attend to it.
            key = torch.cat([k, key], dim=-2)      # cat key
            value = torch.cat([v, value], dim=-2)  # cat value

        query_length, key_length = query.size(-2), key.size(-2)

        # causal mask with generation in mind
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x
class GPTNeoXLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim * 4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        # order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        # x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out
        return x, kv
class GPTNeoXModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoXLayer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        x = self.lm_head(x)
        if kv:
            return x.float(), kv
        else:
            return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        if kv is None:
            kv = [None] * self.n_layer
        kv_new = []
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)
        x = self.ln_final(x)
        if cache:
            return x, kv_new
        else:
            return x, None
class GPTNeoXBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPTNeoXModel

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPTNeoXBaseLM.load(config, path, state_dict)
    return model
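The GPT-NeoX variant threads a per-layer key/value cache through forward(..., kv=..., cache=True), which is what the new sampling code relies on. A rough sketch of incremental decoding with that cache follows; it is illustrative only and not part of the commit, and the small CPU config plus greedy argmax selection are assumptions made for the example.

import torch
from basedformer.gptneox import GPTNeoXModel

model = GPTNeoXModel(hidden_dim=64, n_layer=2, n_head=4, vocab_dim=256, eps=1e-5,
                     device="cpu", dtype=torch.float32)
prompt = torch.randint(0, 256, (1, 8))

# prefill: run the whole prompt once and keep each layer's (key, value) tensors
logits, kv = model(prompt, cache=True)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

# decode: feed only the newest token; the rotary offset and the causal-mask slice
# are derived from the cached key length inside SelfAttention.forward
for _ in range(4):
    logits, kv = model(next_token, cache=True, kv=kv)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)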
basedformer/utils.py
@@ -94,7 +94,7 @@ class SplitCheckpoint(MutableMapping):
    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

-def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=False):
+def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=True):
    precision = 'ns'
    r_arr = np.empty([2, r])  # [0] = mean, [1] = std
    if function:
@@ -104,9 +104,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
        n_arr = np.empty(n)
        for k in range(n):
            start = time.perf_counter_ns()
-           torch.cuda.synchronize()
+           if cuda_blocking:
+               torch.cuda.synchronize()
            func()
-           torch.cuda.synchronize()
+           if cuda_blocking:
+               torch.cuda.synchronize()
            n_arr[k] = time.perf_counter_ns() - start
        if not first:
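The cuda_blocking flag exists because CUDA kernels launch asynchronously: without a torch.cuda.synchronize() on each side of the measured call, the CPU timer can stop before the GPU work has finished and the measurement becomes meaningless. A minimal standalone sketch of the same idea (timed_ns is a made-up helper for illustration, not the timeit in utils.py):

import time
import torch

def timed_ns(fn, cuda_blocking=True):
    # wait for any pending GPU work so it is not billed to fn
    if cuda_blocking and torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter_ns()
    fn()
    # wait for fn's own kernels to finish before stopping the clock
    if cuda_blocking and torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter_ns() - start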
eval_tasks/__init__.py 0 → 100644 (new empty file)
scripts/test_cache.py
@@ -2,6 +2,7 @@ from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
+import functorch
import time
import sys
@@ -17,19 +18,24 @@ def apply_top_k(logits, k):
    # filter the logits that are not in the top-k to -inf
    # keep top_k_ind and filter the rest
    top_k_values = logits.topk(k)[0]
-   remove_mask = logits < top_k_values[:, -1].unsqueeze(0)
+   remove_mask = logits < top_k_values[:, -1].unsqueeze(-1)
    logits[remove_mask == True] = -float("inf")
    return logits

def apply_top_p(logits, p):
    logits = torch.softmax(logits, dim=-1)
    sorted, indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(sorted, dim=-1)
    cumulative_probs = cumulative_probs.scatter(dim=-1, index=indices, src=cumulative_probs)
-   remove_mask = cumulative_probs > p
-   logits[remove_mask == True] = -float("inf")
+   mask_tensor = cumulative_probs > p
+   # Shift the indices to the right to keep also the first token above the threshold
+   mask_tensor[..., 1:] = mask_tensor[..., :-1].clone()
+   mask_tensor[..., 0] = 0
+   mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
+   logits[mask_tensor == True] = -float("inf")
    return logits

def apply_tfs(logits, tfs):
    logits = torch.softmax(logits, dim=-1)
    sorted, indices = torch.sort(logits, descending=True)
    d = sorted
    d = d[:, 1:] - d[:, :-1]
@@ -37,83 +43,261 @@ def apply_tfs(logits, tfs):
    d = d.abs()
    d = d / d.sum(dim=-1).view(1, -1).T
    cumulative_probs = torch.cumsum(d, dim=-1)
    cumulative_probs = cumulative_probs.scatter(dim=-1, index=indices, src=cumulative_probs)
-   remove_mask = cumulative_probs > tfs
-   logits[remove_mask == True] = -float("inf")
+   mask_tensor = torch.empty(indices.shape).cuda()
+   mask_tensor[:, 1:-1] = (cumulative_probs > tfs)[:, :]
+   # Always remove last token
+   mask_tensor[:, -1:] = True
+   # Always keep the first token
+   mask_tensor[:, 0] = False
+   mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
+   logits[mask_tensor == True] = -float("inf")
    return logits
-def temperature(logits, temperature):
+def apply_typical(logits, mass=0.9):
    scores = logits
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)
    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative mass above the threshold
    last_ind = (cumulative_probs < mass).sum(dim=1)
    last_ind[last_ind < 0] = 0
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, -float("inf"))
    return scores

def apply_temp(logits, temperature):
    logits = logits / temperature
    return logits
-def generate(forward, prompt_tokens, temperature, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
+def rep_pen(input_ids, scores, penalty, m=3.33, penalize_last=250, alpha_frequency=None, alpha_presence=None, whitelist=None,):
    scores = torch.log_softmax(scores, dim=-1)
    penalty = 1.0 if penalty < 1.0 else penalty
    raw_penalty = penalty
    penalize_last = None
    if not m is None and not penalize_last is None and penalize_last >= 1:
        penalty = (torch.arange(penalize_last) / (penalize_last - 1)) * 2. - 1
        penalty = (m * penalty) / (1 + torch.abs(penalty) * (m - 1))
        penalty = 1 + ((penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
        penalize_last = penalize_last

    alpha_enable = alpha_frequency is not None or alpha_presence is not None

    whitelist = None
    whitelist_list = None
    if whitelist is not None:
        whitelist_list = whitelist

    ##########
    if whitelist is None and whitelist_list is not None:
        whitelist_list = list(filter(lambda x: x >= 0 and x < scores.shape[1], whitelist_list))
        if len(whitelist_list) > 0:
            whitelist = torch.tensor(whitelist_list).long().sort()[0]
            whitelist = whitelist.to(input_ids.device)

    if whitelist is not None:
        unpenalized = scores.gather(1, whitelist.view(1, -1))

    if raw_penalty > 1.0:
        if not penalize_last is None:
            penality_len = min(input_ids.shape[1], penalize_last)
            input_ids = input_ids[:, -penality_len:]
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        if not penalize_last is None:
            penalty = penalty.type(score.dtype).to(score.device)
            score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:])
        else:
            score = torch.where(score < 0, score * penalty, score / penalty)

        scores.scatter_(1, input_ids, score)

    if alpha_enable:
        c = torch.zeros(scores.shape).long().to(input_ids.device)
        # unique only returns counts for first item in batch, so manually iterate
        for i in range(input_ids.shape[0]):
            if penalize_last is not None:
                token_input_ids, counts = torch.unique(input_ids[i, -penalize_last:], sorted=True, return_counts=True, dim=-1)
            else:
                token_input_ids, counts = torch.unique(input_ids[i], sorted=True, return_counts=True, dim=-1)
            c[i].scatter_(0, token_input_ids, counts)
        if alpha_frequency:
            scores -= c * alpha_frequency
        if alpha_presence:
            scores[c > 0] -= alpha_presence

    if whitelist is not None:
        scores.scatter_(1, whitelist.view(1, -1), unpenalized)

    return scores

def func_multinomial(x):
    torch.manual_seed(69)
    return torch.multinomial(x, 1)
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
    with torch.no_grad():
        in_tokens = prompt_tokens
        context = prompt_tokens
        print(context.shape)
        kv = None
        fully_deterministic = False
        tokens_generated = []
        soft_required = ["top_k", "top_p"]
        # soft_required = ["top_k", "top_p"]
        op_map = {
            "top_k": apply_top_k,
            "top_p": apply_top_p,
-           "temp": temperature,
-           "tfs": apply_tfs
+           "typical": apply_typical,
+           "temp": apply_temp,
+           "tfs": apply_tfs,
+           "rep_pen": rep_pen,
        }
        funcnomial = functorch.vmap(func_multinomial, randomness="different")

        for _ in range(tokens_to_generate):
            logits, kv = forward(in_tokens, cache=True, kv=kv)
            # always work on softmax logits to make sure all models
            # behave similarly as logprobs can be quite different
            # TODO: can break compatibility with novelai presets.
            # logits should be the last token in the sequence
            logits = logits[:, -1, :]  # get the last token in the seq
            logits = torch.log_softmax(logits, dim=-1)
            # can save one softmax here by not applying softmax for the first op,
            # need to take the softmax out of the necessary functions though
            batch = []
            for i, ops in enumerate(ops_list):
-               batch = []
                item = logits[i, ...].unsqueeze(0)
                ctx = context[i, ...].unsqueeze(0)
                ic("------")
                for op, value in ops.items():
                    if op in soft_required:
                        item = torch.log_softmax(logits[i, :, :], dim=-1)
                    ic(op, value)
                    if op == "rep_pen":
                        item = op_map[op](ctx, item, **value)
-                       item = op_map[op](item, value)
                    else:
                        item = op_map[op](item, value)
                batch.append(item)

            logits = torch.cat(batch, dim=0)
            logits = torch.softmax(logits, dim=-1)
-           logits = torch.multinomial(logits, 1)

            # fully_deterministic makes it deterministic in the batch
            if fully_deterministic:
                logits = logits.split(1, dim=0)
                logit_list = []
                for logit in logits:
                    torch.manual_seed(69)
                    logit_list.append(torch.multinomial(logit, 1))
                logits = torch.cat(logit_list, dim=0)
            else:
                torch.manual_seed(69)
                logits = torch.multinomial(logits, 1)

            context = torch.cat([context, logits], dim=1)
            in_tokens = logits

        return context
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"temp": 0.9}):
    with torch.no_grad():
        in_tokens = prompt_tokens
        kv = None
        fully_deterministic = False
        tokens_generated = []
        op_map = {
            "top_k": apply_top_k,
            "top_p": apply_top_p,
            "typical": apply_typical,
            "temp": apply_temp,
            "tfs": apply_tfs
        }

        for _ in range(tokens_to_generate):
            logits, kv = forward(in_tokens, cache=True, kv=kv)
            logits = logits[:, -1, :]  # get the last token in the seq
            logits = torch.log_softmax(logits, dim=-1)
            for op, value in ops.items():
                logits = op_map[op](logits, value).float()

            logits = torch.softmax(logits, dim=-1).float()

            if fully_deterministic:
                logits = logits.split(1, dim=0)
                logit_list = []
                for logit in logits:
                    torch.manual_seed(69)
                    logit_list.append(torch.multinomial(logit, 1))
                logits = torch.cat(logit_list, dim=0)
            else:
                torch.manual_seed(69)
                logits = torch.multinomial(logits, 1)

            in_tokens = logits
            tokens_generated.append(logits)

        tokens_generated = torch.cat(tokens_generated, dim=-1)
        return tokens_generated
def main():
    bsz = 4
    gen_len = 250
    torch.manual_seed(69)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    prompt = """I fucked her with my huge donut, when she seen my donut she went"""
    prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
    tokens = tokenizer.encode(prompt)
    print("Prompt:")
    for x in range(len(tokens)):
        print(tokenizer.decode([tokens[x]]), end=" | ")
    print("\nGeneration:")
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    tokens = [tokens] * bsz
    #tokens = torch.cat([tokens, tokens], dim=0)
    tokens = torch.cat(tokens, dim=0)
    t = time.perf_counter()
    model = gptj.load_gpt_j().cuda().half().eval()
    model = model.lm
    ic(time.perf_counter() - t)
    rep_pen = {
        "penalty": 1000000,
    }
-   ops = {"top_k": 40, "top_p": 0.9, "temp": 0.9}
+   ops = {"rep_pen": rep_pen, "top_k": 50, "temp": 0.8}
    ops_list = [ops] * bsz
-   tokens_generated = generate(model.forward, tokens, 40, ops=ops)
-   tokens_generated = tokenizer.decode(tokens_generated.squeeze().tolist())
    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
    #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
    print(tokens_generated.shape)
    ic(prompt)
    ic(tokens_generated)
    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
    for gen in tokens_generated:
        print(str(gen))
        print("===========================================================")
    #ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
    #timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
    #timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)

if __name__ == "__main__":
    main()
\ No newline at end of file
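To make the sampling pipeline in this script easier to follow in isolation, here is a standalone sketch of the per-token flow that generate() drives through op_map. It is illustrative only: it uses dummy logits and hard-coded top-k/temperature values instead of the script's ops_list, and it does not depend on a loaded model.

import torch

torch.manual_seed(69)
logits = torch.randn(1, 50257)              # [batch, vocab], stand-in for model output

# top-k: mask out everything below the k-th largest logit
k = 50
kth_value = logits.topk(k)[0][:, -1].unsqueeze(-1)
logits = logits.masked_fill(logits < kth_value, -float("inf"))

# temperature, then sample one token id from the resulting distribution
logits = logits / 0.8
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, 1)    # shape [1, 1]; generate() appends this to the context
print(next_token)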