novelai-storage / Basedformer · Commits

Commit fb134d28
Authored Jul 06, 2022 by novelailab
zero2 works
Parent: 9d27a5cc
Showing 3 changed files with 41 additions and 33 deletions (+41, -33):
  basedformer/models/gptj.py   +13  -10
  basedformer/optimizer.py     +4   -0
  finetune.py                  +24  -23
basedformer/models/gptj.py (view file @ fb134d28)

@@ -95,15 +95,18 @@ class SelfAttention(nn.Module):
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
-        self.fused_softmax = FusedScaleMaskSoftmax(
-            input_in_fp16=False,
-            input_in_bf16=True,
-            mask_func=attention_mask_func,
-            scale=None,
-            softmax_in_fp32=False,
-            attn_mask_type="causal",
-            scaled_masked_softmax_fusion=True,
-        )
+        if self.config.masked_softmax_fusion:
+            self.fused_softmax = FusedScaleMaskSoftmax(
+                input_in_fp16=False,
+                input_in_bf16=True,
+                mask_func=attention_mask_func,
+                scale=None,
+                softmax_in_fp32=False,
+                attn_mask_type="causal",
+                scaled_masked_softmax_fusion=True,
+            )
+        else:
+            self.fused_softmax = None

     def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim

@@ -242,7 +245,7 @@ class GPTJModel(base_lm.BaseModel):
             'activation': gelu_new,
             'SelfAttention': SelfAttention,
             'FeedForward': FeedForward,
-            'masked_softmax_fusion': False,
+            'masked_softmax_fusion': True,
         }
         base_lm.BaseModel.__init__(self, user_config, **kwargs)
         if self.config.masked_softmax_fusion:
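Note: the change above makes the fused softmax kernel optional behind a masked_softmax_fusion config flag (and enables it by default for GPTJModel). As a rough illustration of what that gating implies for the attention path, here is a minimal sketch assuming a plain masked-softmax fallback when the flag is off; the fallback, function name, and tensor shapes are assumptions, not code from this commit:

# Illustrative sketch only: branching on the optional fused_softmax handle
# gated by masked_softmax_fusion above. The unfused fallback is an assumption.
import torch
import torch.nn.functional as F

def attention_softmax(scores, causal_mask, fused_softmax=None):
    # scores: (batch, heads, q_len, k_len); causal_mask is True where attention is disallowed
    if fused_softmax is not None:
        # fused kernel handles scale, mask, and softmax in one pass
        return fused_softmax(scores, causal_mask)
    scores = scores.masked_fill(causal_mask, torch.finfo(scores.dtype).min)
    return F.softmax(scores, dim=-1)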
basedformer/optimizer.py (view file @ fb134d28)

@@ -77,6 +77,10 @@ class BasedOptimizer:
                 eps=self.eps,
             )
+        elif self.optimizer_name == "zero2":
+            from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
+            self.optimizer = DistributedFusedAdam(self.parameters, lr=0, weight_decay=self.weight_decay,
+                                                  betas=(self.beta1, self.beta2), eps=self.eps, grad_sync_dtype=torch.float32)
         elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
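Note: for reference, a standalone sketch of what the new "zero2" branch constructs. The constructor arguments mirror the diff; the helper name and the distributed-setup assertion are illustrative assumptions, and lr is passed as 0 presumably because BasedOptimizer drives the real learning rate from the lr/warmup settings in the training config:

# Sketch only: mirrors the arguments of the new "zero2" branch above.
# The helper name and the process-group check are assumptions.
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

def make_zero2_optimizer(params, weight_decay, betas, eps):
    # DistributedFusedAdam shards optimizer state across ranks (ZeRO-style),
    # so a process group must already be initialized.
    assert dist.is_initialized()
    return DistributedFusedAdam(
        params,
        lr=0,                           # placeholder lr, as in the diff
        weight_decay=weight_decay,
        betas=betas,
        eps=eps,
        grad_sync_dtype=torch.float32,  # keep gradient reduction in fp32
    )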
finetune.py (view file @ fb134d28)

@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
 from torch.utils import data
-from basedformer import optimizer, utils, lm_utils
+from basedformer import optimizer, utils, lm_utils, dataset
 import yaml
 import sys
 from tqdm import tqdm
@@ -16,17 +16,11 @@ import os
 from icecream import ic
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
+#from torch.nn.parallel import DistributedDataParallel as DDP
+from apex.parallel.distributed import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from dotmap import DotMap
 import argparse
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel,
-    CPUOffload,
-)
-from torch.distributed.fsdp.wrap import (
-    default_auto_wrap_policy,
-)

 def setup(rank, world_size):
     #os.environ['MASTER_ADDR'] = 'localhost'
@@ -97,14 +91,19 @@ def fsdp_train(args, model, train_loader, opt):
                 norm = norm.matmul(norm.transpose(-1, -2))
                 contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
                 gas_loss += contrastive_loss * args.contrastive_loss
             if args["loss_scale"]:
-                scaler.scale(gas_loss).backward()
+                with opt.optimizer.no_sync():
+                    scaler.scale(gas_loss).backward()
             else:
-                gas_loss.backward()
+                with opt.optimizer.no_sync():
+                    gas_loss.backward()
             loss += gas_loss.item()
         loss = loss / gas
+        opt.optimizer.grad_sync()
         if args["loss_scale"]:
             scaler.unscale_(opt.optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
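Note: the hunk above defers gradient communication to match DistributedFusedAdam: each micro-batch backward runs under opt.optimizer.no_sync(), and a single opt.optimizer.grad_sync() at the end of the accumulation window reduce-scatters the accumulated gradients. A condensed sketch of that pattern; the loss computation, micro-batch iterable, and helper name are placeholders, not code from this commit:

# Condensed sketch of the accumulation pattern above (illustrative only).
import torch

def accumulation_step(model, opt, micro_batches, scaler=None):
    for inputs, labels in micro_batches:            # `gas` micro-batches per optimizer step
        gas_loss = model(inputs, labels)            # assumed to return a scalar loss
        with opt.optimizer.no_sync():               # skip per-backward gradient communication
            if scaler is not None:
                scaler.scale(gas_loss).backward()
            else:
                gas_loss.backward()
    opt.optimizer.grad_sync()                       # one reduce-scatter for the whole window
    if scaler is not None:
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)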
@@ -116,10 +115,10 @@ def fsdp_train(args, model, train_loader, opt):
         if args["loss_scale"]:
             scaler.update()
-        #opt.zero_grad()
-        model.zero_grad(set_to_none=True)
+        opt.zero_grad()
+        #model.zero_grad(set_to_none=True)
         sec_per_step = (time.perf_counter() - timex)
-        flops = get_flops(args, model.module, sec_per_step)
+        flops = get_flops(args, model, sec_per_step)
         step_per_sec = (1. / sec_per_step)
         tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
         batch_size = bs * gas * world_size
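Note: a quick worked example of the throughput lines above, using the new config values bs=2 and gas=8 with an assumed world_size of 8 ranks running one optimizer step per second (world_size and step rate are assumptions, not from the commit):

# Hypothetical numbers plugged into the formulas above; only bs and gas come
# from train_config, the rest are assumptions for illustration.
step_per_sec = 1.0
bs, gas, world_size = 2, 8, 8
tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size   # 262144.0 tokens/s
batch_size = bs * gas * world_size                               # 128 sequences per optimizer step

ShardedDataset yields 2049-token samples while the throughput counts 2048, presumably because one token per sample is consumed as the shifted target.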
@@ -153,15 +152,17 @@ def main(rank, global_rank, world_size, args):
     setup(rank, world_size)
     Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
-    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
-    fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/gpt-j-base").half().to(rank)
+    #fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    #fsdp_model = DDP(model)
+    fsdp_model = model
     utils.print_parameters(fsdp_model)
     ic("model loaded")
-    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
+    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero2")
     # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
     print(opt.curr_step)
-    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
+    train_dataset = dataset.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
     train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
     if global_rank == 0:
         wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
@@ -172,21 +173,21 @@ def main(rank, global_rank, world_size, args):
 if __name__ == "__main__":
     train_config = {
-        "data_path": "dataset/sigurd-1G.map",
+        "data_path": "/home/xuser/nvme1/dataset/sigurd-1G.map",
         "save_path": "models/gptj-sigurd-1G-vanilla",
-        "do_save": True,
+        "do_save": False,
         "run_name": "gptj-sigurd-1G-vanilla",
         "lr": 6e-5,
         "end_lr": 3e-5,
         "warmup_steps": 100,
         "anneal_steps": 7850,
         "bs": 2,
-        "gas": 2,
+        "gas": 8,
         "seed": 69,
         "save_every": 500,
-        "amp": True,
+        "amp": False,
         "loss_scale": True,
-        "cast_to": torch.float16,
+        "cast_to": torch.bfloat16,
         "contrastive_loss": False,
     }