novelai-storage / Basedformer, commit a9eca288
authored Mar 18, 2022 by novelailab
parent 9a167649

it trains!

Showing 9 changed files with 187 additions and 97 deletions (+187, -97)
Files changed:

  .gitignore             +1   -0
  act_ck.py              +41  -10
  configs/train.yaml     +3   -0
  gelu_test.py           +22  -0
  lm_train/optimizer.py  +20  -12
  lm_train/utils.py      +4   -17
  main.py                +38  -3
  test_pyfra.py          +3   -2
  train.py               +55  -53
.gitignore  (+1, -0)

@@ -131,3 +131,4 @@ dmypy.json
 models
 gptjconvert
 j6b_vanilla
+wandb
\ No newline at end of file
act_ck.py  (+41, -10)

@@ -61,21 +61,52 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 def rndinput(shape):
     return torch.randint(0, 50256, shape).long().cuda()

-def forward(model, x):
-    out = model.get_logits(x, act_ck=False)
+@torch.no_grad()
+def forward(model, x, hypernetwork=None):
+    out = model.get_logits(x, hypernetwork=hypernetwork, act_ck=True)
     print(out.shape)
-    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
-    loss = torch.nn.CrossEntropyLoss()(out, out)
-    loss.backward()
-    model.zero_grad()
-    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
+    #print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
+    #loss = torch.nn.CrossEntropyLoss()(out, out)
+    #loss.backward()
+    #model.zero_grad()
+    #print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))

+class HyperNetwork(nn.Module):
+    def __init__(self, hidden_size, num_layers):
+        super().__init__()
+        embed_dim = hidden_size
+        self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
+        state = self.state_dict()
+        for k in state:
+            state[k] = state[k] * 1 / math.sqrt(2 * num_layers)
+        self.load_state_dict(state)
+
+    def forward(self, hidden_states):
+        hidden_states = self.linear(hidden_states)
+        hidden_states = hidden_states.mul(torch.sigmoid(hidden_states))
+        return hidden_states

 def main():
-    model = init_1_3b().cuda().half()
-    shape = (1, 2048)
+    model = init_6b().cuda().half()
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for param in model.vocab_embed.parameters():
+        param.requires_grad = True
+
+    for x in model.layers:
+        for param in x.ln_preattn.parameters():
+            param.requires_grad = True
+
+    hypernetwork = HyperNetwork(4096, 28).cuda().half()
+    hypernetwork.train()
+    shape = (1, 1)
     #print(model(x).shape)
     print("PyTorch Eager")
-    timeit(r=1, n=2, func=lambda: forward(model, rndinput(shape)), do_tqdm=False, first=False)
+    timeit(r=1, n=2, func=lambda: forward(model, rndinput(shape), hypernetwork), do_tqdm=False, first=False)

 if __name__ == "__main__":
     main()
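
For reference, the HyperNetwork added here is a single square linear layer whose output is gated by its own sigmoid (hidden_states.mul(torch.sigmoid(hidden_states)), which is exactly SiLU), with every parameter scaled by 1/sqrt(2 * num_layers) at construction so the module starts out small. A self-contained sketch that mirrors the class above; the hidden_size=64 / num_layers=28 sizes and the CPU-only check are illustrative, not from the commit:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Standalone copy of act_ck.py's HyperNetwork, shrunk so it runs on CPU.
class HyperNetwork(nn.Module):
    def __init__(self, hidden_size, num_layers):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=True)
        # Scale all parameters by 1/sqrt(2 * num_layers) so the initial output is small.
        state = self.state_dict()
        for k in state:
            state[k] = state[k] / math.sqrt(2 * num_layers)
        self.load_state_dict(state)

    def forward(self, hidden_states):
        hidden_states = self.linear(hidden_states)
        return hidden_states * torch.sigmoid(hidden_states)

hn = HyperNetwork(hidden_size=64, num_layers=28)
x = torch.randn(2, 16, 64)                       # (batch, seq, hidden)
out = hn(x)
# x * sigmoid(x) is SiLU, so the forward pass is F.silu of the projection.
assert torch.allclose(out, F.silu(hn.linear(x)), atol=1e-6)
print(out.shape)                                 # torch.Size([2, 16, 64])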
configs/train.yaml  (new file 0 → 100644, +3, -0)

{
    "lr": 1.0e-4,
}
\ No newline at end of file
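
Although configs/train.yaml is written with JSON-style braces, that is valid YAML flow syntax (PyYAML also tolerates the trailing comma), so it loads directly with the yaml module that train.py imports. A quick check, assuming PyYAML is installed and the script runs from the repo root:

import yaml

with open("configs/train.yaml") as f:
    config = yaml.safe_load(f)

print(config)               # {'lr': 0.0001}
print(type(config["lr"]))   # <class 'float'>: PyYAML resolves '1.0e-4' to a float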
gelu_test.py  (new file 0 → 100644, +22, -0)

import torch
import math

@torch.jit.script
def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_slow(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_trace(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_involved(x):
    return gelu_new(x)

#torch.jit.trace gelu
#code:
gelu_traced = torch.jit.trace(gelu_involved, torch.randn(1, 128, 128))

x = torch.rand(1, 128, 128)
assert torch.allclose(gelu_new(x), gelu_involved(x))
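
gelu_test.py defines the same tanh-approximation GELU three ways (TorchScript-scripted, plain eager, and traced) but only asserts that they agree numerically. A rough timing harness in the same spirit, assuming CPU and a fixed input size (numbers will vary by hardware and PyTorch version):

import math
import time
import torch

@torch.jit.script
def gelu_scripted(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_eager(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

gelu_traced = torch.jit.trace(gelu_eager, torch.randn(1, 128, 128))

def bench(fn, x, iters=1000):
    fn(x)  # warmup call so the JIT can profile/optimize before timing
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters

x = torch.randn(1, 128, 128)
for name, fn in [("scripted", gelu_scripted), ("eager", gelu_eager), ("traced", gelu_traced)]:
    print(f"{name}: {bench(fn, x) * 1e6:.1f} us/call")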
lm_train/optimizer.py  (+20, -12)

@@ -11,18 +11,26 @@ def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
 class BasedOptimizer:
     def __init__(self, parameters, config, optimizer):
-        self.lr = config["lr"]
-        self.end_lr = config["end_lr"] if "end_lr" in config else self.lr
-        self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
-        self.anneal_steps = config["anneal_steps"] if "anneal_steps" in config else 1
-        self.total_steps = config["total_steps"] if "total_steps" in config else None
-        self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
-        self.tokens = config["tokens"] if "tokens" in config else None
-        self.epochs = config["epochs"] if "epochs" in config else None
         # tokens and epochs should not be here. calculate it somewhere else and find how many steps, then pass to the BasedOptimizer
-        self.beta1 = config["beta1"] if "beta1" in config else 0.9
-        self.beta2 = config["beta2"] if "beta2" in config else 0.95
-        self.eps = config["eps"] if "eps" in config else 1e-4
+        defaults = {
+            "lr": 6e-4,
+            "end_lr": 6e-4,
+            "warmup_steps": 1,
+            "anneal_steps": 1,
+            "total_steps": None,
+            "weight_decay": 0,
+            "tokens": None,
+            "epochs": None,
+            "beta1": 0.9,
+            "beta2": 0.95,
+            "eps": 1e-4,
+        }
+
+        for k, v in defaults.items():
+            setattr(self, k, v)
+
+        for k, v in config.items():
+            setattr(self, k, v)

         self.max_lr = False
         self.curr_step = 0
         self.curr_lr = 0
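
The rewrite replaces the long chain of config[...] if ... in config else ... expressions with a defaults dict that is applied via setattr and then overridden by whatever the caller passes. The same pattern on a toy class (names and values here are illustrative, not the repo's):

defaults = {"lr": 6e-4, "end_lr": 6e-4, "warmup_steps": 1, "weight_decay": 0, "beta1": 0.9, "beta2": 0.95, "eps": 1e-4}

class ToyOptimizerConfig:
    def __init__(self, config):
        # apply defaults first, then let the caller's config win
        for k, v in defaults.items():
            setattr(self, k, v)
        for k, v in config.items():
            setattr(self, k, v)

cfg = ToyOptimizerConfig({"lr": 1.0e-4, "warmup_steps": 100})
print(cfg.lr, cfg.warmup_steps, cfg.beta2)  # 0.0001 100 0.95

One side effect of this pattern is that unrecognized keys in config also become attributes; merging the dicts first with {**defaults, **config} would behave the same way in fewer lines.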
lm_train/utils.py  (+4, -17)

@@ -6,30 +6,17 @@ import torch
 # Does this work with other block_sizes? doesn't seem to.
 class FbDataset(data.Dataset):
     def __init__(self, block_size, map_file, max_samples=None):
-        self.half_blocks = False
-        if block_size is not None and int(block_size) < 2048:
-            self.half_blocks = True
-        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, 2048))
+        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
         self.samples = self.npz.shape[0]
-        if self.half_blocks:
-            self.samples *= 2
-        if not max_samples is None:
+        if max_samples is not None:
             self.samples = min(self.samples, int(max_samples))
         self.skip = 0

     def __len__(self):
         return self.samples

     def __getitem__(self, _id):
         nth = _id + self.skip
-        offset = 0
-        length = 2048
-        if self.half_blocks:
-            nth = _id // 2
-            offset = 1024 * (_id % 2)
-            length = 1024
-        data = torch.tensor(self.npz[nth][offset:offset+length].astype(np.int64))
-        return (data, data)
+        data = torch.tensor(self.npz[nth].astype(np.int64))
+        return (data[:-1], data[1:])

 # Make loading models faster by not letting pytorch initialize the weights.
 # Usage: no_init(lambda: load_model(...))
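
The FbDataset change switches from returning the same block twice to returning a shifted pair: with block_size 2049 (as train.py now passes), data[:-1] is the 2048-token input and data[1:] is the target, the standard one-token shift for next-token prediction. A tiny illustration with a fake 2049-token row standing in for the real uint16 memmap:

import numpy as np
import torch

row = np.arange(2049, dtype=np.uint16)           # stands in for one row of the memmap
data = torch.tensor(row.astype(np.int64))

inputs, labels = data[:-1], data[1:]             # what __getitem__ now returns
assert inputs.shape == labels.shape == (2048,)
assert torch.equal(inputs[1:], labels[:-1])      # labels are inputs shifted left by one token
print(inputs[:5].tolist(), labels[:5].tolist())  # [0, 1, 2, 3, 4] [1, 2, 3, 4, 5]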
main.py  (+38, -3)

@@ -72,7 +72,6 @@ class SplitCheckpoint(MutableMapping):
 def get_logits(x, embedding):
     return embedding(x)

 @torch.jit.script
 def gelu_new(x):
     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

@@ -266,8 +265,8 @@ class GPTModel(nn.Module):
         x = self.ln_final(x)
         return x

-    def get_logits(self, x, act_ck=False):
-        x = self.forward(x, act_ck=act_ck)
+    def get_logits(self, x, hypernetwork=None, act_ck=False):
+        x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
         x = self.lm_head(x)
         return x.float()

@@ -285,6 +284,22 @@ class GPTModel(nn.Module):
         model = cls(**config)
         return model

+    @classmethod
+    def neox_init(cls, config):
+        model = cls(**config)
+        modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
+        init = small_init_method(config["hidden_dim"])
+        for module in modules:
+            for param in module.parameters():
+                init(param)
+
+        last_layer = model.layers[-1]
+        last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
+        for param in last_layer.parameters():
+            last_layer_init(param)
+
+        return model

     def save(self, path):
         try:
             os.mkdir(path)
         except:
             pass

@@ -297,6 +312,26 @@ class GPTModel(nn.Module):
 # TODO: Do we want to have the LM head as a seperate Class? Or just a function? I think we might be better off with a function here and maybe
 # also for the self attention, we can just write a function that gets fed in the q, k, v.

+def wang_init_method(n_layers, dim):
+    std = 2 / n_layers / math.sqrt(dim)
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+    return init_
+
+# Stolen from NeoX. For the 20B run wang_init used on the output layer and small_init on rest of the layers.
+def small_init_method(dim):
+    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
+    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
+    std = math.sqrt(2 / (5 * dim))
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+    return init_

 def load_gpt_j(path="models/6b", state_dict=None):
     config = {
         "n_layer": 28,
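
To put numbers on the two init schemes added above: small_init_method draws weights with std sqrt(2 / (5 * dim)), while wang_init_method, which neox_init applies only to the last layer, uses std 2 / (n_layers * sqrt(dim)). For the small model_config in train.py (hidden_dim=768, n_layer=12) that works out to roughly:

import math

hidden_dim, n_layer = 768, 12                    # values from train.py's model_config

small_std = math.sqrt(2 / (5 * hidden_dim))      # ~0.0228, used for most modules
wang_std = 2 / n_layer / math.sqrt(hidden_dim)   # ~0.0060, used for the last layer only
print(f"small_init std = {small_std:.4f}, wang_init std = {wang_std:.4f}")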
test_pyfra.py  (+3, -2)

@@ -7,7 +7,7 @@ dry = False
 config_obj = KubeConfig()
 config_obj.set_name(name)
-config_obj.set_gpu(gpu_name=GPU.RTX_A5000, amount=1)
+config_obj.set_gpu(gpu_name=GPU.A100_PCIE_40GB, amount=1)
 config_obj.set_ram(16)
 config_obj.set_cpu(4)
 config_obj.dry_run(dry)

@@ -23,7 +23,8 @@ env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-li
 env1.sh('pip install einops numpy')
 env1.sh('pip install tqdm')
 env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
-env1.sh('pip3 install einops==0.4.1')
+env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
+env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')

 with always_rerun():
     print(f"Running {sys.argv[1]}")
     path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
train.py  (+55, -53)

@@ -3,60 +3,62 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.optim as optim
 from lm_train import optimizer, utils
 from torch.utils import data
 from main import *
 import yaml
 import sys
 from tqdm import tqdm
 import time
 import wandb

-#Based Optimizer
-class BasedOptimizer:
-    def __init__(self, model, config, optimizer):
-        self.min_lr = config["min_lr"] if "min_lr" in config else 1e-06
-        self.warmup_end = config["lr"] if "lr" in config else 5e-06
-        self.warmup_init = config["warmup_init"] if "warmup_init" in config else 0
-        self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
-        self.total_steps = config["total_steps"] if "total_steps" in config else None
-        self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
-        self.start_step = config["start_step"] if "start_step" in config else 0
-        self.curr_step = self.start_step
-        self.curr_lr = 0
+model_config = {
+    "n_layer": 12,
+    "n_head": 12,
+    "hidden_dim": 768,
+    "vocab_dim": 50400,
+    "eps": 1e-5,
+    "activation": gelu_new,
+    "Layer": GPTLayer,
+}
+
+optim_func = optim.AdamW
+
+# we need 250 batch size to train the small GPT.
+train_config = {
+    "lr": 6e-4,
+    "end_lr": 6e-4,
+    "warmup_steps": 100,
+    "bs": 16,
+    "gas": 2,
+    "seed": 69,
+}
+bs = train_config["bs"]
+gas = train_config["gas"]
+
+model = GPTModel.neox_init(model_config).cuda().bfloat16()
+opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
+train_dataset = utils.FbDataset(2049, "sigurd_v5_2049.map")
+train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
+wandb.init(project="basedformer-tests", name="sigurd_v5_2049")

-        self.optimizers = optim_func(model.parameters(), lr=self.warmup_init, weight_decay=self.weight_decay, betas=config["betas"], eps=config["eps"])
+t = tqdm(train_loader)
+for input_ids, labels in t:
+    timex = time.perf_counter()
+    input_ids = input_ids.cuda()
+    labels = labels.cuda()
+    loss = 0
+    for x in range(train_config["gas"]):
+        logits = model.get_logits(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=True)
+        logits = logits.view(-1, logits.shape[-1])
+        gas_labels = labels[x*bs:(x+1)*bs, :]
+        gas_labels = gas_labels.view(-1)
+        gas_loss = F.cross_entropy(logits, gas_labels)
+        gas_loss.backward()
+        loss += gas_loss.item()

-    def get_current_lr(self):
-        cosine_lr = self.min_lr + 0.5 * (self.warmup_end - self.min_lr) * (1 + math.cos(math.pi * min(1.0, max(0, self.curr_step - self.warmup_steps) / (self.total_steps - self.warmup_steps))))
-        target_lr = self.warmup_end if self.curr_step < self.warmup_steps else cosine_lr
-        return inter(self.warmup_init, target_lr, max(0, self.curr_step - self.start_step) / max(1, self.warmup_steps))
-        return min(self.end_lr * (self.curr_step / self.warmup_steps), self.end_lr)
-
-    def backward(self, loss):
-        self.optimizers[0].backward(loss, update_master_grads=False)
-        #loss.backward()
-
-    def step(self, scaler=None):
-        self.curr_lr = self.get_current_lr()
-        for optimizer in self.optimizers:
-            for paramx in optimizer.param_groups:
-                paramx['lr'] = self.curr_lr
-            optimizer.update_master_grads()
-        if scaler:
-            for optimizer in self.optimizers:
-                scaler.step(optimizer)
-        else:
-            optimizer.step()
-        self.curr_step += 1
-
-    def zero_grad(self):
-        for optimizer in self.optimizers:
-            optimizer.zero_grad()
-
-    def print_info(self):
-        print(f"min_lr: {str(self.min_lr)}")
-        print(f"warmup_end: {str(self.warmup_end)}")
-        print(f"warmup_init: {str(self.warmup_init)}")
-        print(f"warmup_steps: {str(self.warmup_steps)}")
-        print(f"start_step: {str(self.start_step)}")
-        print(f"total_steps: {str(self.total_steps)}")
-        print(f"weight_decay: {str(self.weight_decay)}")
-        print(f"step: {str(self.curr_step)}")
-        print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
+    loss = loss / gas
+    opt.step()
+    opt.zero_grad()
+    sec_per_step = (time.perf_counter() - timex) / (bs * gas)
+    step_per_sec = (1. / sec_per_step)
+    tokens_per_sec = step_per_sec * 2048
+    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
+    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
\ No newline at end of file
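
The new training loop uses plain gradient accumulation: each DataLoader batch of size bs*gas is split into gas micro-batches, each micro-batch's loss is backpropagated immediately so gradients add up in .grad, and the optimizer steps once per outer iteration. A self-contained sketch of the same pattern; the toy linear model and SGD stand in for the real GPTModel and BasedOptimizer:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 4)                      # toy stand-in for the language model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

bs, gas = 8, 2                                # micro-batch size and accumulation steps
inputs = torch.randn(bs * gas, 16)            # one "large" batch, as delivered by the DataLoader
labels = torch.randint(0, 4, (bs * gas,))

loss = 0.0
for x in range(gas):
    logits = model(inputs[x * bs:(x + 1) * bs])
    gas_loss = F.cross_entropy(logits, labels[x * bs:(x + 1) * bs])
    gas_loss.backward()                       # gradients accumulate across micro-batches
    loss += gas_loss.item()

loss = loss / gas                             # report the mean loss, as train.py does
opt.step()                                    # one optimizer step per large batch
opt.zero_grad()
print(f"loss={loss:.4f}")

Because the micro-batch losses are not divided by gas before backward(), the accumulated gradient is the sum over micro-batches rather than the mean; only the logged loss is averaged, which matches what the loop above does.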