novelai-storage / Basedformer / Commits

Commit 971ed5dc
Authored Apr 09, 2022 by novelailab
Parent: 41b51369

optimizer save load completed
Showing 4 changed files with 80 additions and 35 deletions (+80, -35).
basedformer/lm_base.py       +1   -2
basedformer/optimizer.py     +51  -32
hypertrain.py                +0   -1
scripts/test_optimizer.py    +28  -0
basedformer/lm_base.py  (+1, -2)

```diff
@@ -6,7 +6,7 @@ from basedformer import gptj
 import os
 import json
 from dataclasses import dataclass
-
+from pathlib import Path
 '''
 BaseLM config dataclass:
     model_config = {
@@ -27,7 +27,6 @@ class BaseLMConfig():
     vocab_dim: int
     eps: float
-
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
     def __init__(self, config=None, lm=None):
```
basedformer/optimizer.py  (+51, -32)

```diff
@@ -3,18 +3,29 @@ import numpy as np
 import torch
 from dotmap import DotMap
 import pickle
+import os
+from pathlib import Path
 
 #Based Optimizer
-def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
+def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
     warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
     anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
-    #cosine schedule for annealing
-    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    #kinda broken. doesn't start from 0
+    if cosine_warmup:
+        main_lr = lr * (1 - np.cos(np.pi * warmup_percent)) / 2
+    else:
+        main_lr = lr * warmup_percent
+
+    anneal_lr = (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    return main_lr - anneal_lr
 
 class BasedOptimizer:
-    def __init__(self, parameters, config, optimizer):
+    def __init__(self, parameters, config, optimizer, init=True):
+        if init:
+            self.init_config(config)
+            self.init_optimizer(parameters, optimizer)
+
+    def init_config(self, config):
         defaults = {
             "lr": 6e-4,
             "end_lr": 6e-4,
```
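The reworked `lr_schedule` above keeps the linear warmup by default (the cosine warmup branch is flagged as broken by the in-code comment) and anneals from `lr` down to `end_lr` along a cosine curve. A quick way to sanity-check it outside the training loop is the sketch below; the function body is copied from the hunk above, the settings mirror `train_config` from `scripts/test_optimizer.py` further down, and the sampled step values are only illustrative.

```python
import numpy as np

# Copy of the new lr_schedule from the hunk above: linear (or optional cosine)
# warmup up to lr, then a cosine anneal from lr down to end_lr.
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
    warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
    anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
    if cosine_warmup:
        main_lr = lr * (1 - np.cos(np.pi * warmup_percent)) / 2
    else:
        main_lr = lr * warmup_percent
    anneal_lr = (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
    return main_lr - anneal_lr

# Settings taken from train_config in scripts/test_optimizer.py; step values are illustrative.
for step in (0, 50, 100, 145, 190):
    print(step, lr_schedule(step, warmup_steps=100, anneal_steps=90, lr=5e-4, end_lr=1e-4))
```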
```diff
@@ -27,6 +38,9 @@ class BasedOptimizer:
             "beta1": 0.9,
             "beta2": 0.95,
             "eps": 1e-4,
+            "max_lr": False,
+            "curr_step": 0,
+            "curr_lr": 0,
         }
 
         for k, v in defaults.items():
```
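With `max_lr`, `curr_step` and `curr_lr` moved into `defaults`, the schedule state is now initialized through the same defaults-then-config merge as the hyperparameters, which is what lets `load()` later rebuild an optimizer from a pickled `__dict__`. The body of the defaults loop sits in a collapsed part of the diff, so the sketch below only illustrates the visible pattern and assumes plain `setattr` for both passes; `ConfigSketch` is a made-up stand-in, not code from the repository.

```python
# Standalone sketch of the defaults-then-config merge used by init_config.
class ConfigSketch:
    def __init__(self, config):
        defaults = {
            "lr": 6e-4,
            "end_lr": 6e-4,
            "beta1": 0.9,
            "beta2": 0.95,
            "eps": 1e-4,
            "max_lr": False,
            "curr_step": 0,
            "curr_lr": 0,
        }
        for k, v in defaults.items():
            setattr(self, k, v)      # assumed: defaults applied first
        for k, v in config.items():
            setattr(self, k, v)      # visible in the diff: user config overrides

cfg = ConfigSketch({"lr": 5e-4, "end_lr": 1e-4})
print(cfg.lr, cfg.beta1, cfg.curr_step)  # 0.0005 0.9 0
```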
```diff
@@ -35,36 +49,35 @@ class BasedOptimizer:
         for k, v in config.items():
             setattr(self, k, v)
 
-        self.max_lr = False
-        self.curr_step = 0
-        self.curr_lr = 0
-        if optimizer == "adamw":
-            self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer == "adamw8bit":
+    def init_optimizer(self, parameters, optimizer_name):
+        if optimizer_name == "adamw":
+            self.optimizer = optim.AdamW(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif optimizer_name == "adamw8bit":
             import bitsandbytes as bnb
-            self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer == "adafactor":
+            self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
             except ImportError:
                 raise ImportError("Please install transformers for Adafactor")
-            self.optimizer = Adafactor(params=parameters)
+            self.optimizer = Adafactor(params=self.parameters)
 
-    def step(self, scaler=None):
-        if scaler:
-            scaler.step(self.optimizer)
-        else:
-            self.optimizer.step()
+    def step(self, dry_run=False, scaler=None):
+        if not dry_run:
+            if scaler:
+                scaler.step(self.optimizer)
+            else:
+                self.optimizer.step()
 
-        self.curr_step = self.curr_step + 1
-        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
+        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
+        self.curr_step = self.curr_step + 1
         if not self.max_lr:
             if self.curr_lr == self.end_lr:
                 print("max lr reached.")
```
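`step()` now takes a `dry_run` flag: when it is set, the wrapped optimizer is not stepped at all, but `curr_lr` and `curr_step` still advance, so a resumed run can fast-forward its schedule (this is exactly what the new test script does). Below is a minimal sketch of both call paths; it reuses the constructor call from `scripts/test_optimizer.py`, while the toy model and loss are illustrative only.

```python
import torch
from basedformer import optimizer

train_config = {"lr": 5e-4, "end_lr": 1e-4, "warmup_steps": 100, "anneal_steps": 90}
model = torch.nn.Linear(10, 100)
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")

# Regular step (scaler=None path): backward, then the wrapped optimizer steps.
# With torch.cuda.amp, passing scaler=... would route through scaler.step() instead.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()

# Schedule-only step: skips optimizer.step() but still advances the LR schedule.
opt.step(dry_run=True)
print(opt.curr_step, opt.curr_lr)
```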
```diff
@@ -85,15 +98,21 @@ class BasedOptimizer:
         if self.curr_step != 0:
             print(f"curr_lr: {str(self.get_current_lr())}")
 
-    def save(self, path):
-        torch.save(self.optimizer.state_dict(), path)
-        with open(path, 'wb') as f:
-            pickle.dump(self, f)
+    def save(self, path: Path):
+        path = path / "opt"
+        path.mkdir(parents=True, exist_ok=True)
+        torch.save(self.optimizer.state_dict(), path / "opt_states.pt")
+        del self.optimizer
+        metadata = self.__dict__
+        with open(path / "opt_metadata.pkl", 'wb') as f:
+            pickle.dump(metadata, f)
 
     @classmethod
-    def load(cls, path):
-        with open(path, 'rb') as f:
-            based_optimizer = pickle.load(f)
-
-        based_optimizer.optimizer.load_state_dict(torch.load(path))
+    def load(cls, parameters, path):
+        path = path / "opt"
+        with open(path / "opt_metadata.pkl", 'rb') as f:
+            metadata = pickle.load(f)
+
+        based_optimizer = cls(parameters, metadata, metadata["optimizer_name"])
+        based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
         return based_optimizer
\ No newline at end of file
```
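The save/load rework replaces the single pickled object with a small checkpoint directory: `save()` writes the optimizer state dict to `<path>/opt/opt_states.pt` and pickles the remaining attributes (`metadata = self.__dict__`) to `<path>/opt/opt_metadata.pkl`, while `load()` rebuilds the instance from that metadata plus freshly supplied parameters before restoring the state dict. A round-trip sketch, mirroring the paths used by the new test script:

```python
import torch
from pathlib import Path
from basedformer import optimizer

train_config = {"lr": 5e-4, "end_lr": 1e-4, "warmup_steps": 100, "anneal_steps": 90}
model = torch.nn.Linear(10, 100)
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# ... train for a while ...

# Writes models/test_optimizer2/opt/opt_states.pt and .../opt_metadata.pkl.
# Note that save() deletes self.optimizer, so the instance cannot be stepped
# again afterwards; save at the end of a run.
save_dir = Path("models/test_optimizer2")
opt.save(save_dir)

# Later (or in a new process): rebuild from fresh parameters + saved metadata,
# then restore the optimizer state dict on top.
opt = optimizer.BasedOptimizer.load(model.parameters(), save_dir)
print(opt.curr_step, opt.curr_lr)
```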
hypertrain.py  (+0, -1)

```diff
@@ -146,7 +146,6 @@ class HyperNetworkSingle(nn.Module):
         return x.bfloat16()
 
 model_config = {
-    "model_class":
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
```
scripts/test_optimizer.py  (new file, mode 100644, +28)

```python
from basedformer import optimizer
import torch
from tqdm import tqdm
import wandb
import os
from pathlib import Path

train_config = {
    "lr": 5e-4,
    "end_lr": 1e-4,
    "warmup_steps": 100,
    "anneal_steps": 90,
}

model = torch.nn.Linear(10, 100)
save_folder = "models/test_optimizer2"

if not os.path.isdir(save_folder + "/opt"):
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
else:
    opt = optimizer.BasedOptimizer.load(model.parameters(), Path(save_folder))

wandb.init(project="opt-test", name="test")

for x in tqdm(range(opt.curr_step, 100)):
    print(f"Step {opt.curr_step}: LR {opt.curr_lr}")
    wandb.log({"lr": opt.curr_lr})
    opt.step(dry_run=True)
    #if x == 60:
        #opt.save(Path(save_folder))
```
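Because the loop starts at `range(opt.curr_step, 100)` and every iteration calls `opt.step(dry_run=True)`, the script only exercises the scheduler and the save/load path: the dummy `Linear` model is never actually optimized, and when a `models/test_optimizer2/opt` checkpoint already exists the run resumes the LR curve from the saved `curr_step` instead of starting over.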