Commit 86e815ab authored Apr 12, 2022 by novelailab (novelai-storage / Basedformer)

upstream

parent 1ab8bac6

Showing 4 changed files with 507 additions and 33 deletions (+507 -33)
basedformer/gpt2.py      +213 -0
basedformer/lm_base.py   +1 -1
basedformer/noemblm.py   +228 -0
train.py                 +65 -32
basedformer/gpt2.py  (new file, 0 → 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base

def shift_tokens(x, amt, eps=1e-5):
    n, device = x.shape[1], x.device

    cumsum = x.cumsum(dim=1)
    *x, x_pass = x.chunk(amt + 1, dim=-1)
    *x_cumsum, _ = cumsum.chunk(amt + 1, dim=-1)

    amts = 2 ** torch.arange(amt)
    amts = amts.tolist()
    shifts = []
    denom = torch.arange(n, device=device)

    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
        shifted_chunk = shift(x_cumsum_chunk, amt, dim=-2) - shift(x_cumsum_chunk, 2 * amt, dim=-2)
        shifted_denom = shift(denom, amt, dim=-1) - shift(denom, 2 * amt, dim=-1)
        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
        shifts.append(normed_shifted_x)

    return torch.cat((*shifts, x_pass), dim=-1)

def shift(x, amt, dim=-1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)

def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x

class GPT2Layer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        residual = residual + attn_out
        x = self.ln_postattn(x)
        ff_out = self.ff(x, act_ck)
        x = residual + ff_out
        return x

class GPT2Model(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new,
                 Layer=GPT2Layer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(
                Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head,
                      eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x

class GPT2BaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPT2Model

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPT2BaseLM.load(config, path, state_dict)
    return model
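Not part of the commit: a minimal usage sketch of the new GPT2Model, assuming the basedformer package in this tree is importable. The tiny dimensions are arbitrary, and F.gelu stands in for the repo's gelu_new default activation; CPU/float32 is used so the fp32 lm_head matches the rest of the model.

# Hypothetical usage sketch (not in the commit): build a small GPT2Model and run one forward pass.
import torch
import torch.nn.functional as F
from basedformer.gpt2 import GPT2Model

model = GPT2Model(
    hidden_dim=256,
    n_layer=2,
    n_head=4,
    vocab_dim=50400,
    eps=1e-5,
    activation=F.gelu,    # stand-in for gelu_new from basedformer.utils
    device="cpu",
    dtype=torch.float32,
)

tokens = torch.randint(0, 50400, (1, 128))   # (batch, seq); seq must stay <= max_positions (2049)
logits = model(tokens, act_ck=False)         # forward() upcasts logits to float32
print(logits.shape)                          # torch.Size([1, 128, 50400])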
basedformer/lm_base.py

@@ -51,7 +51,7 @@ class BaseLM(nn.Module):
         for name, p in module.named_parameters():
             if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config.n_layer)))
+                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config["n_layer"])))

     @classmethod
     def init(cls, config):
...
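The one-line change swaps attribute access for dict access, since the configs used elsewhere in this commit (for example load_gpt_j above) are plain dicts. A small worked example of the resulting depth-scaled init std, using the GPT-J-sized config from this commit as an assumed input:

# Illustrative only: the special init std applied to ff2/out_proj weights.
import math

config = {"n_layer": 28}                        # plain dict, as used elsewhere in this commit
std = 0.02 / math.sqrt(2 * config["n_layer"])   # works; config.n_layer would raise AttributeError
print(round(std, 5))                            # 0.00267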
basedformer/noemblm.py  (new file, 0 → 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base

def token_shift(x, window_size=1):
    size = x.size()[-1] // (window_size + 1)
    def shift(x, t, s):
        return nn.functional.pad(x[:, :-t, (t - 1) * s:t * s], (0, 0, t, 0))[:, :x.size()[-2], :]
    time_shifts = [shift(x, t, size) for t in range(1, window_size + 1)]
    current_x = [x[:, :, len(time_shifts) * size:]]
    x = torch.cat(time_shifts + current_x, dim=-1)
    return x

def token_shift_no_mix(x, window_size=1):
    size = x.size()[-1] // (window_size + 1)
    def shift(x, t, s):
        return nn.functional.pad(x[:, :-t, :s], (0, 0, t, 0))[:, :x.size()[-2], :]
    time_shifts = [shift(x, t, size) for t in range(window_size, 0, -1)]
    current_x = [x[:, :, len(time_shifts) * size:]]
    x = torch.cat(time_shifts + current_x, dim=-1)
    return x

def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = token_shift_no_mix(x)
        x = self.ff2(x)
        return x

class GPTNELayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False,
                diff_hypernets=False, interleaving_layers=False, every_n=5):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out
        return x

class GPTNEModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new,
                 Layer=GPTNELayer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(
                Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head,
                      eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x

class GPTNEBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPTNEModel

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPTNEBaseLM.load(config, path, state_dict)
    return model
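Not part of the commit: a small sanity-check sketch of token_shift_no_mix (assuming the basedformer package in this tree is importable). With window_size=1, the first half of each position's channels is replaced by the previous position's first half (zeros at position 0), while the second half stays in place.

# Hypothetical check of token_shift_no_mix behavior on a tiny tensor.
import torch
from basedformer.noemblm import token_shift_no_mix

x = torch.arange(2 * 4 * 6, dtype=torch.float32).view(2, 4, 6)  # (batch, seq, dim)
y = token_shift_no_mix(x, window_size=1)

assert y.shape == x.shape
assert torch.equal(y[:, 0, :3], torch.zeros(2, 3))    # first position sees zero-padding
assert torch.equal(y[:, 1:, :3], x[:, :-1, :3])       # shifted-in half comes from position t-1
assert torch.equal(y[:, :, 3:], x[:, :, 3:])          # second half of channels is unchanged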
train.py
...
@@ -4,39 +4,37 @@ import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
-from lm_train import utils
 from torch.utils import data
-from basedformer import lm_base, optimizer
+from basedformer import optimizer, utils, gptj, noemblm, gpt2
 import yaml
 import sys
 from tqdm import tqdm
 import time
 import wandb
 import numpy as np
+import os

 model_config = {
-    "n_layer": 12,
-    "n_head": 12,
-    "hidden_dim": 768,
+    "n_layer": 3,
+    "n_head": 16,
+    "hidden_dim": 1024,
     "vocab_dim": 50400,
     "eps": 1e-5,
-    "activation": gelu_new,
-    "Layer": GPTLayer
 }

 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
+    #"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
+    "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
+    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshiftnomix-3L-16A-1024H",
-    "run_name": "gpt-j-owt2-6b-preattn",
+    "do_save": True,
+    "run_name": "gptj-nopos-tokenshiftnomix-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
     "lr": 1e-4,
     "end_lr": 1e-4,
-    "warmup_steps": 50,
+    "warmup_steps": 100,
-    "bs": 12,
+    "bs": 16,
-    "gas": 10,
+    "gas": 1,
     "seed": 69,
     "save_every": 500,
     "amp": True,
...
@@ -48,17 +46,37 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

-model = GPTModel.gpt2_init(model_config).cuda().float()
-opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
-# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
+#model = GPTModel.gpt2_init(model_config).cuda().float()
+model = noemblm.GPTNEBaseLM.init(model_config).cuda().float()
+model.train()
+
+cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
+last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
+print(last_cp)
+
+if last_cp:
+    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
+    model.load(model_config, last_cp / "lm", strict=True)
+    opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
+else:
+    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
+
+# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
+print(opt.curr_step)

 train_dataset = utils.FbDataset(2049, train_config["data_path"])
-train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
-wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
-t = tqdm(train_loader)
-curr_step = 0
+if last_cp:
+    train_dataset.skip = opt.curr_step * bs * gas
+
+train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
+wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
+
+if last_cp:
+    curr_step = opt.curr_step
+else:
+    curr_step = 0
+
+t = tqdm(train_loader, initial=curr_step)

 scaler = torch.cuda.amp.GradScaler()
...
@@ -69,9 +87,10 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=False)
+            logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
+            #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
-            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
+            gas_labels = labels[x*bs:(x+1)*bs, :512].contiguous()
             gas_labels = gas_labels.view(-1)
             gas_loss = F.cross_entropy(logits, gas_labels)
...
@@ -88,7 +107,6 @@ for input_ids, labels in t:
     torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

     if train_config["loss_scale"]:
         opt.step(scaler=scaler)
     else:
         opt.step()
...
@@ -96,12 +114,27 @@ for input_ids, labels in t:
     scaler.update()
     opt.zero_grad()
-    sec_per_step = (time.perf_counter() - timex) / (bs * gas)
+    sec_per_step = (time.perf_counter() - timex)
     step_per_sec = (1. / sec_per_step)
-    tokens_per_sec = step_per_sec * 1024
+    tokens_per_sec = (step_per_sec * 1024) * bs * gas
     t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
-    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
-    curr_step += 1
-    if curr_step % train_config["save_every"] == 0:
-        model.save(train_config["save_path"] + f"/{curr_step}")
-        print(f"Saved model at step {curr_step}")
+    wandb.log(
+        {
+            "train/loss": loss,
+            "train/tokens_per_sec": tokens_per_sec,
+            "train/sec_per_step": sec_per_step,
+            "train/step_per_sec": step_per_sec,
+            "train/lr": opt.curr_lr,
+            "train/loss_scale": scaler.get_scale()
+        },
+        step=curr_step)
+
+    if train_config["do_save"]:
+        if curr_step % train_config["save_every"] == 0:
+            save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
+            save_folder.mkdir(parents=True, exist_ok=True)
+            model.save(save_folder / "lm")
+            opt.save(save_folder / "opt")
+            print(f"Saved model at step {curr_step}")
+
+    curr_step += 1
\ No newline at end of file
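An illustration, not part of the commit, of the checkpoint layout and resume behavior this diff introduces: every save_every steps a folder save_path/step_<N> is written with "lm" and "opt" subfolders, and on restart the newest step folder is chosen by sorting numerically on the suffix. The path below is made up for the example.

# Hypothetical resume helper mirroring the diff's cp_list/last_cp logic.
from pathlib import Path

save_path = Path("models/owt2gptj-nopos-tokenshiftnomix-3L-16A-1024H")   # illustrative local path
step_dirs = [p.name for p in save_path.iterdir() if p.is_dir()] if save_path.exists() else []
cp_list = sorted(step_dirs, key=lambda name: int(name.split("_")[-1]))   # "step_500" -> 500
last_cp = save_path / cp_list[-1] if len(cp_list) > 0 else None
print(last_cp)   # e.g. .../step_1500, containing "lm" (model weights) and "opt" (optimizer state)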