novelai-storage / Basedformer

Commit 347ef912
authored Mar 21, 2022 by novelailab
parent b07251f0

pickle optimizer save/load

Showing 2 changed files with 25 additions and 4 deletions:

lm_train/optimizer.py   +16  -1
train.py                 +9  -3
lm_train/optimizer.py

  from torch import optim
  import numpy as np
+ import torch
+ from dotmap import DotMap
+ import pickle
  #Based Optimizer
  def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
  ...

@@ -61,4 +64,16 @@ class BasedOptimizer:
          print(f"weight_decay: {str(self.weight_decay)}")
          print(f"step: {str(self.curr_step)}")
          if self.curr_step != 0:
-             print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
+             print(f"curr_lr: {str(self.get_current_lr())}")
+
+     def save(self, path):
+         torch.save(self.optimizer.state_dict(), path)
+         with open(path, 'wb') as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def load(cls, path):
+         with open(path, 'rb') as f:
+             based_optimizer = pickle.load(f)
+         based_optimizer.optimizer.load_state_dict(torch.load(path))
+         return based_optimizer
\ No newline at end of file
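The new methods give BasedOptimizer a self-contained checkpoint: save() pickles the wrapper object and writes the inner torch optimizer state to the given path, and load() reverses both steps. A minimal round-trip sketch of the intended usage, inferred from the constructor call in train.py; the stand-in module, the config keys, and the checkpoint filename below are illustrative assumptions, not part of the commit:

import torch
from lm_train import optimizer   # assuming lm_train is importable as a package

# Stand-in module and config purely for illustration; the real run uses
# GPTModel and the train_config dict defined in train.py.
model = torch.nn.Linear(8, 8)
train_config = {"lr": 6e-4, "end_lr": 6e-5, "warmup_steps": 100,
                "anneal_steps": 1000, "weight_decay": 0.01}   # assumed keys

opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
opt.save("opt_checkpoint.pt")                              # pickle + state_dict to one path
opt = optimizer.BasedOptimizer.load("opt_checkpoint.pt")   # restore without re-passing the config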
train.py

  ...
@@ -33,19 +33,21 @@ train_config = {
      "bs": 16,
      "gas": 16,
      "seed": 69,
+     "save_every": 50,
  }

  bs = train_config["bs"]
  gas = train_config["gas"]

  model = GPTModel.neox_init(model_config).cuda().bfloat16()
  opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
- # TODO: Add load, add evals, add FP16 AMP, and Data Parallel.
+ # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
  train_dataset = utils.FbDataset(2049, train_config["data_path"])
  train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
  wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})

  t = tqdm(train_loader)
+ curr_step = 0
  for input_ids, labels in t:
      timex = time.perf_counter()
      input_ids = input_ids.cuda()
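For orientation (not part of the diff): the DataLoader batch above is the full accumulation batch, which the loop body presumably slices into gas micro-batches of bs sequences before each opt.step(). A quick, purely illustrative check of that arithmetic with the config values from this hunk:

# Sanity check of the batch arithmetic implied by train_config (illustrative only).
bs, gas = 16, 16
loader_batch = bs * gas            # batch_size handed to data.DataLoader -> 256 sequences
accum_passes = loader_batch // bs  # 16 backward passes accumulated before opt.step()
assert accum_passes == gas
print(f"{loader_batch} sequences per optimizer step, {accum_passes} micro-batches")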
  ...
@@ -59,7 +61,7 @@ for input_ids, labels in t:
          gas_loss = F.cross_entropy(logits, gas_labels)
          gas_loss.backward()
          loss += gas_loss.item()

      loss = loss / gas
      opt.step()
      opt.zero_grad()
  ...
@@ -67,4 +69,8 @@ for input_ids, labels in t:
      step_per_sec = (1. / sec_per_step)
      tokens_per_sec = step_per_sec * 2048
      t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
-     wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
\ No newline at end of file
+     wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
+     curr_step += 1
+     if curr_step % train_config["save_every"] == 0:
+         model.save(train_config["save_path"])
+         print(f"Saved model at step {curr_step}")
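The new save_every hook checkpoints only the model. Since the same commit introduces BasedOptimizer.save/load, a hedged sketch of how the optimizer could be checkpointed and restored alongside it follows; this is an extension of the loop above, not code from the diff, and the ".opt" path suffix is an assumption:

    # Hypothetical extension of the save_every block above (not in this commit).
    if curr_step % train_config["save_every"] == 0:
        model.save(train_config["save_path"])
        opt.save(train_config["save_path"] + ".opt")   # persist optimizer state via the new save()
        print(f"Saved model and optimizer at step {curr_step}")

# On a restart, the TODO's "Add load" could reuse the new classmethod instead of
# constructing a fresh optimizer:
# opt = optimizer.BasedOptimizer.load(train_config["save_path"] + ".opt")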